mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68368 Currently, each instance of `StaticRuntime` has its own copy of `std::function` object wrapped in `ProcessedNode::Function` object, in order to invoke actual operation implementation. However, all instances of `StaticRuntime` derived from same `StaticModule` objects invoke exactly same op implementation, and this is avoidable. This change adds `StaticModule::functions_` member variable to keep a list of unique instance of `ProcessedFunction` objects. A newly constructed `StaticRuntime` takes `ProcessedFunction`'s pointers instead of the whole function object. This can save a substantial amount of memory per `StaticRuntime` instance. This comes with a sacrifice in execution time. Now that a `ProcessedNode` instance keeps the function object's pointer, executing a node now involves an extra pointer dereference. However, this cost was proved to be negligible from local performance tests. Thanks to hlu1 for proposing this non-intrusive improvement idea :D Test Plan: This change reduces the size of a StaticRuntime instance by 14.41% (459KB -> 393KB) (patched D32181666 to print the memory turnover from instantiating a StaticRuntime instance) for CMF/local ( & 8% for CMF/local_ro). No noticeable latency regression was observed. ==AFTER * CMF/local memory turnover: 393608 latency: PyTorch run finished. Milliseconds per iter: 15.6965. Iters per second: 63.7087 * CMF/local_ro memory turnover:387288 latency: PyTorch run finished. Milliseconds per iter: 7.51308. Iters per second: 133.101 ==BEFORE * CMF/local memory turnover: 459888 latency: PyTorch run finished. Milliseconds per iter: 15.8278. Iters per second: 63.18 * CMF/local_ro memory turnover: 420832 latenfcy: PyTorch run finished. Milliseconds per iter: 7.43756. 
Iters per second: 134.453 ==Confirmation that ptvsc2_predictor_bench reports the same memrmoy management stats for inline_cvr: ==AFTER Total number of managed tensors: 2660 Total number of managed output tensors: 0 Total number of unmanaged values: 3041 Total memory managed: 1496896 bytes Total number of reused tensors: 1183 Total number of 'out' variant nodes/total number of nodes: 2452/2469 (99.3115%) Total number of managed tensors: 1412 Total number of managed output tensors: 0 Total number of unmanaged values: 2677 Total memory managed: 39040 bytes Total number of reused tensors: 959 Total number of 'out' variant nodes/total number of nodes: 1928/1937 (99.5354%) Total number of managed tensors: 1293 Total number of managed output tensors: 0 Total number of unmanaged values: 14 Total memory managed: 5293824 bytes Total number of reused tensors: 771 Total number of 'out' variant nodes/total number of nodes: 1298/1298 (100%) ==BEFORE Total number of managed tensors: 2660 Total number of managed output tensors: 0 Total number of unmanaged values: 3041 Total memory managed: 1496896 bytes Total number of reused tensors: 1183 Total number of 'out' variant nodes/total number of nodes: 2452/2469 (99.3115%) Total number of managed tensors: 1412 Total number of managed output tensors: 0 Total number of unmanaged values: 2677 Total memory managed: 39040 bytes Total number of reused tensors: 959 Total number of 'out' variant nodes/total number of nodes: 1928/1937 (99.5354%) Total number of managed tensors: 1293 Total number of managed output tensors: 0 Total number of unmanaged values: 14 Total memory managed: 5293824 bytes Total number of reused tensors: 771 Total number of 'out' variant nodes/total number of nodes: 1298/1298 (100%) Reviewed By: swolchok Differential Revision: D32337548 fbshipit-source-id: e714e735399c93fde337b0f70e203a2de632057a
986 lines
33 KiB
C++
986 lines
33 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/csrc/jit/ir/alias_analysis.h>
|
|
#include <torch/csrc/jit/ir/irparser.h>
|
|
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
|
|
#include <torch/csrc/jit/runtime/static/fusion.h>
|
|
#include <torch/csrc/jit/runtime/static/impl.h>
|
|
#include <torch/csrc/jit/runtime/static/ops.h>
|
|
#include <torch/csrc/jit/runtime/static/passes.h>
|
|
|
|
#include "deep_wide_pt.h"
|
|
#include "test_utils.h"
|
|
|
|
using namespace torch;
|
|
using namespace torch::jit;
|
|
using namespace torch::jit::test;
|
|
|
|
namespace {
|
|
|
|
StaticModule makeStaticModuleFromScript(const std::string& script) {
|
|
Module m("module");
|
|
m.define(script);
|
|
return StaticModule(m);
|
|
}
|
|
|
|
bool testCanEnableStaticRuntime(const std::string& jit_script) {
|
|
script::Module module("module");
|
|
module.define(jit_script);
|
|
|
|
Method method = module.get_method("forward");
|
|
auto graph = module.get_method("forward").graph();
|
|
|
|
// here we do not freeze graph
|
|
return canEnableStaticRuntime(graph);
|
|
}
|
|
|
|
bool testHasInplaceOp(const std::string& jit_script) {
|
|
script::Module module("module");
|
|
module.define(jit_script);
|
|
|
|
Method method = module.get_method("forward");
|
|
auto graph = module.get_method("forward").graph();
|
|
|
|
AliasDb alias_db(graph);
|
|
return HasInplaceOp(graph, alias_db);
|
|
}
|
|
|
|
bool testModuleHasOp(const std::string& jit_script, const char* op_name) {
|
|
script::Module module("module");
|
|
module.define(jit_script);
|
|
|
|
return forwardHasOp(module, op_name);
|
|
}
|
|
|
|
const auto reshape_inplace_script = R"JIT(
|
|
def forward(self, inp: Tensor, shape: List[int]):
|
|
a = inp + inp
|
|
b = a.reshape(shape)
|
|
c = b.sigmoid_()
|
|
d = c + c
|
|
e = a + a
|
|
f = b + b
|
|
return (d, e, f)
|
|
)JIT";
|
|
|
|
const auto reshape_inplace_script_1 = R"JIT(
|
|
def forward(self, inp: Tensor, shape: List[int], flag: bool):
|
|
if flag:
|
|
a = inp + inp
|
|
b = a.reshape(shape)
|
|
c = b.sigmoid()
|
|
else:
|
|
a = inp * inp
|
|
b = a.sigmoid_()
|
|
c = b.reshape(shape)
|
|
d = c + c
|
|
e = a + a
|
|
f = b + b
|
|
return (d, e, f)
|
|
)JIT";
|
|
|
|
const auto sigmoid_inplace_script = R"JIT(
|
|
def forward(self, inp: Tensor):
|
|
a = torch.sigmoid(inp, out=inp).clone()
|
|
return (a)
|
|
)JIT";
|
|
|
|
const auto sigmoid_out_script = R"JIT(
|
|
def forward(self, inp: Tensor):
|
|
a = inp + inp
|
|
b = torch.sigmoid(inp, out=a).clone()
|
|
return (b)
|
|
)JIT";
|
|
|
|
} // namespace
|
|
|
|
// Test that StaticModule::value_group groups values of the graph into
// 1) Inputs/Constants and their aliases 2) Outputs and their aliases.
TEST(StaticModule, ValueGroup) {
  const std::string src = R"IR(
    graph(%input0 : Tensor, %input1 : Tensor):
      # Constants.
      %0 : int = prim::Constant[value=1]()
      # Internal values.
      %1 : Tensor = aten::add(%input0, %input1, %0)
      # This includes aliases of output.
      %2 : Tensor = aten::add(%input0, %1, %0)
      # This includes output.
      %3 : (Tensor) = prim::TupleConstruct(%2)
      return (%3)
  )IR";
  auto input_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, input_graph.get());
  torch::jit::StaticModule sm(input_graph);
  const Graph& graph = sm.graph();
  // nodes[0] is the prim::Constant, nodes[1] produces %1, nodes[2] produces
  // %2, per the IR above.
  std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
  const auto& value_group = sm.value_group();

  // Graph inputs and the constant's output belong to the external group.
  std::vector<const Value*> expected_input_aliases{
      graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
  for (auto* value : expected_input_aliases) {
    EXPECT_TRUE(value_group.isExternalAlias(value));
  }

  // The tuple output and the tensor it contains (%2) alias the output.
  std::vector<const Value*> expected_output_aliases{
      graph.outputs()[0], nodes[2]->output()};
  for (auto* value : expected_output_aliases) {
    EXPECT_TRUE(value_group.isOutputAlias(value));
  }
  // %1 is internal-only: its storage may be reclaimed during the run.
  EXPECT_FALSE(value_group.isAlwaysAlive(nodes[1]->output()));
  // Inputs and outputs must stay alive for the whole run.
  EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[0]));
  EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[1]));
  EXPECT_TRUE(value_group.isAlwaysAlive(graph.outputs()[0]));
}
|
|
|
|
TEST(StaticModule, IsOptimizableContainerType_NonOptimizableInputs) {
  // No container-construct node may use an out variant here: the values
  // being packed are not produced by nodes that have out variants.
  const std::string src = R"JIT(
        def forward(self, a, b):
            a_alias = a.view(a.size())
            non_optimizable_list = [a_alias]
            non_optimizable_tuple = (b, )
            return non_optimizable_list, non_optimizable_tuple
    )JIT";

  const auto static_module = makeStaticModuleFromScript(src);
  for (const Node* node : static_module.graph().nodes()) {
    EXPECT_FALSE(static_module.is_optimizable_container_type(node));
  }
}
|
|
|
|
TEST(StaticModule, IsOptimizableContainerType_WrongType) {
  // Out variants cannot be used for list/tuple construction here because
  // the contained values are ints rather than Tensors.
  const std::string src = R"JIT(
        def forward(self, x: int, y: int):
            a = 1 + x
            b = 2 + y
            non_optimizable_list = [a]
            non_optimizable_tuple = (b, )
            return non_optimizable_list, non_optimizable_tuple
    )JIT";

  const auto static_module = makeStaticModuleFromScript(src);
  for (const Node* node : static_module.graph().nodes()) {
    EXPECT_FALSE(static_module.is_optimizable_container_type(node));
  }
}
|
|
|
|
TEST(StaticModule, IsOptimizableContainerType_CanUseOutVariant) {
  // aten::relu has an out variant and produces a Tensor, so the
  // ListConstruct that packages its result is optimizable; every other
  // node in the graph is not.
  const std::string src = R"JIT(
        def forward(self, x):
            a = torch.relu(x)
            optimizable_list = [a]
            return optimizable_list
    )JIT";
  const auto static_module = makeStaticModuleFromScript(src);

  for (const Node* node : static_module.graph().nodes()) {
    const bool is_list_construct = node->kind() == c10::prim::ListConstruct;
    EXPECT_EQ(
        static_module.is_optimizable_container_type(node), is_list_construct);
  }
}
|
|
|
|
// Test operator() with rvalue inputs
|
|
TEST(StaticModule, RValueInputs) {
|
|
const std::string src = R"JIT(
|
|
def forward(self, x):
|
|
y = torch.relu(x)
|
|
return y.clone()
|
|
)JIT";
|
|
auto sm = makeStaticModuleFromScript(src);
|
|
|
|
std::vector<IValue> input{at::randn({1})};
|
|
|
|
auto expected = sm(input, {});
|
|
auto actual = sm(std::move(input), {});
|
|
|
|
EXPECT_TRUE(expected.isTensor());
|
|
EXPECT_TRUE(actual.isTensor());
|
|
EXPECT_TRUE(expected.toTensor().equal(actual.toTensor()));
|
|
}
|
|
|
|
TEST(StaticRuntime, InPlace) {
  // Scripts that mutate a tensor in place (sigmoid_, or out= targeting the
  // input itself) are flagged by HasInplaceOp.
  EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script));
  EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script_1));
  EXPECT_TRUE(testHasInplaceOp(sigmoid_inplace_script));
  // out= writing into a locally produced tensor is an out-variant, not an
  // in-place op on the input.
  EXPECT_FALSE(testHasInplaceOp(sigmoid_out_script));
}
|
|
|
|
TEST(StaticRuntime, ModuleHasOp) {
  // forwardHasOp reports whether the module's forward graph contains a
  // node with the given operator name.
  EXPECT_TRUE(testModuleHasOp(reshape_inplace_script, "aten::sigmoid_"));
  EXPECT_TRUE(testModuleHasOp(reshape_inplace_script_1, "aten::reshape"));
  EXPECT_TRUE(testModuleHasOp(sigmoid_inplace_script, "aten::clone"));
  // The script only uses `+` (aten::add), never in-place add.
  EXPECT_FALSE(testModuleHasOp(reshape_inplace_script_1, "aten::add_"));
}
|
|
|
|
TEST(StaticRuntime, CanEnableStaticRuntime) {
  // Each script below exercises one construct that Static Runtime rejects.
  const auto while_script = R"JIT(
    def forward(self, a: Tensor, x: int):
        c = 0
        while c < x:
            a = a * a
            c += 2
        return a
  )JIT";

  const auto for_script = R"JIT(
    def forward(self, a: Tensor, x: int):
        for c in range(x):
            a = a * a
        return a
  )JIT";

  const auto if_script = R"JIT(
    def forward(self, a: Tensor, b: bool):
        if b:
            return a
        else:
            return a * a
  )JIT";

  const auto is_script = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return a is b
  )JIT";

  const auto is_not_script = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return a is not b
  )JIT";

  // In-place ops alone do not disqualify a graph.
  EXPECT_TRUE(testCanEnableStaticRuntime(reshape_inplace_script));
  // Loops, branches, and `is`/`is not` comparisons do.
  EXPECT_FALSE(testCanEnableStaticRuntime(for_script));
  EXPECT_FALSE(testCanEnableStaticRuntime(while_script));
  EXPECT_FALSE(testCanEnableStaticRuntime(if_script));
  EXPECT_FALSE(testCanEnableStaticRuntime(is_script));
  EXPECT_FALSE(testCanEnableStaticRuntime(is_not_script));
}
|
|
|
|
// Verifies Static Runtime matches the JIT executor when the output is a
// deeply nested container (dicts/tuples/lists of tensors).
TEST(StaticRuntime, NestedOutput) {
  // dict of tuple of list
  const auto nested_output_script_0 = R"JIT(
    def forward(self, a, b):
        c = (a + b).relu().nan_to_num().float()
        d = a.flatten().nan_to_num() * b.flatten().nan_to_num()
        e = d.float().relu()
        f = ([c], [d])
        g = ([e], [f])
        return ({"prediction":(f, g)})
  )JIT";

  // tuple of lists
  const auto nested_output_script_1 = R"JIT(
    def forward(self, a, b):
        c = (a + b).relu().nan_to_num().float()
        d = a.flatten().nan_to_num() * b.flatten().nan_to_num()
        e = d.float().relu()
        f = [c]
        g = [e]
        return (f, g)
  )JIT";

  // list of tuple of dict
  const auto nested_output_script_2 = R"JIT(
    def forward(self, a, b):
        c = (a + b).relu().nan_to_num().float()
        d = b * c
        e = a.flatten().nan_to_num() * b.flatten().nan_to_num()
        f = e.float().relu()
        g = ({"d": d}, {"b": b})
        h = ({"e": e}, {"f": f})
        return [g, h]
  )JIT";

  // list of dict
  const auto nested_output_script_3 = R"JIT(
    def forward(self, a, b):
        c = (a + b).relu().nan_to_num().float()
        d = b * c
        e = a.flatten().nan_to_num() * b.flatten().nan_to_num()
        f = e.float().relu()
        g = {"d": d, "b": b}
        h = {"e": e, "f": f}
        return [g, h]
  )JIT";

  // Runs all four scripts with inputs of the given shape; when the leading
  // dimension is nonzero, also re-runs the first two with a 3x larger batch
  // to exercise re-planning of managed memory.
  auto run_test = [&](std::vector<int64_t> shapes) {
    auto a = at::randn(shapes);
    auto b = at::randn(shapes);

    std::vector<IValue> args{a, b};
    testStaticRuntime(nested_output_script_0, args);
    testStaticRuntime(nested_output_script_1, args);
    testStaticRuntime(nested_output_script_2, args);
    testStaticRuntime(nested_output_script_3, args);

    if (shapes.size() > 0 && shapes[0] != 0) {
      shapes[0] *= 3;
      testStaticRuntime(
          nested_output_script_0, args, {at::randn(shapes), at::randn(shapes)});
      testStaticRuntime(
          nested_output_script_1, args, {at::randn(shapes), at::randn(shapes)});
    }
  };
  run_test({2, 3, 1, 2});
  run_test({2, 6});
}
|
|
|
|
// test memory reuse
|
|
TEST(StaticRuntime, LongModel) {
|
|
torch::jit::Module mod = getLongScriptModel();
|
|
auto a = torch::randn({2, 2});
|
|
auto b = torch::randn({2, 2});
|
|
auto c = torch::randn({2, 2});
|
|
|
|
// run jit graph executor
|
|
std::vector<at::IValue> input_ivalues({a, b, c});
|
|
at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
|
|
|
|
// run static runtime
|
|
std::vector<c10::IValue> input_tensors({a, b, c});
|
|
torch::jit::StaticModule smod(mod);
|
|
at::Tensor output_2 = smod(input_tensors, {}).toTensor();
|
|
smod.runtime().check_for_memory_leak();
|
|
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
|
}
|
|
|
|
TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module mod = getTrivialScriptModel();
  const auto t0 = torch::randn({2, 2});
  const auto t1 = torch::randn({2, 2});
  const auto t2 = torch::randn({2, 2});

  // Reference result from the JIT graph executor.
  std::vector<at::IValue> jit_inputs({t0, t1, t2});
  const at::Tensor expected = mod.forward(jit_inputs).toTensor();

  // Same inputs through Static Runtime.
  std::vector<c10::IValue> sr_inputs({t0, t1, t2});
  torch::jit::StaticModule smod(mod);
  const at::Tensor actual = smod(sr_inputs, {}).toTensor();
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(expected, actual, 1e-6));
}
|
|
|
|
// Compares Static Runtime against the JIT executor on the deep&wide model.
// Iterating over several batch sizes re-enters the same runtime with
// different input shapes, which exercises memory re-planning.
TEST(StaticRuntime, DeepWide) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(mod);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(mod.forward(inputs));

      // run static runtime
      std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
      auto outputs = smod(input_tensors, {}).toTupleRef().elements();
      ASSERT_TRUE(outputs.size() > 0);
      at::Tensor output_2 = outputs[0].toTensor();
      smod.runtime().check_for_memory_leak();
      EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
    }
  }
}
|
|
|
|
// Positional-argument API: results must match the JIT executor, and the
// runtime must not retain references to inputs or outputs after the call.
TEST(StaticRuntime, KWargsAPI_1) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});
      {
        std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});

        // run jit graph executor
        at::Tensor output_1 = getTensor(module.forward(inputs));

        // run static runtime
        c10::IValue output_ivalue = smod(inputs, {});
        smod.runtime().check_for_memory_leak();

        at::Tensor output_2 = getTensor(output_ivalue);
        EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));

        // check for output aliasing: the caller must hold the only
        // reference to the returned IValue...
        EXPECT_EQ(output_ivalue.use_count(), 1);
        output_ivalue = IValue();

        // ...and, after dropping it, the only reference to the tensor.
        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
      }

      // check for input aliasing (deep & wide does not have ops
      // that create aliases of input tensors)
      EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
    }
  }
}
|
|
|
|
// Keyword-argument API: same checks as KWargsAPI_1, but inputs are passed
// through the kwargs map with an empty positional-argument vector.
TEST(StaticRuntime, KWargsAPI_2) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});
      {
        // run jit graph executor
        std::vector<at::IValue> args({ad_emb_packed, user_emb, wide});
        at::Tensor output_1 = getTensor(module.forward(args));

        // Keys match the parameter names of the model's forward method.
        std::unordered_map<std::string, c10::IValue> kwargs(
            {{"ad_emb_packed", ad_emb_packed},
             {"user_emb", user_emb},
             {"wide", wide}});

        // run static runtime
        c10::IValue output_ivalue = smod(std::vector<IValue>{}, kwargs);
        smod.runtime().check_for_memory_leak();

        at::Tensor output_2 = getTensor(output_ivalue);
        EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));

        // check for output aliasing: caller holds the only references.
        EXPECT_EQ(output_ivalue.use_count(), 1);
        output_ivalue = IValue();

        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
      }

      // check for input aliasing
      EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
    }
  }
}
|
|
|
|
// Exercises every valid combination of StaticModuleOptions to verify that
// memory cleanup/planning works (and leaks are detected) under each one.
TEST(StaticRuntime, CleanUpMemory) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();

  for (auto cleanup_activations : {true, false}) {
    for (auto enable_out_variant : {true, false}) {
      for (auto optimize_memory : {true, false}) {
        for (auto manage_output_tensors : {true, false}) {
          if (manage_output_tensors && !enable_out_variant) {
            // when manage_output_tensors is enabled, enable_out_variant
            // must be enabled too
            continue;
          }
          if (optimize_memory && !enable_out_variant) {
            // when optimize_memory is enabled, enable_out_variant must be
            // enabled too
            continue;
          }
          VLOG(1) << "cleanup_activations: " << cleanup_activations
                  << ", enable_out_variant: " << enable_out_variant
                  << ", optimize_memory: " << optimize_memory
                  << ", manage_output_tensors: " << manage_output_tensors;
          torch::jit::StaticModuleOptions opts{
              cleanup_activations,
              enable_out_variant,
              optimize_memory,
              manage_output_tensors};
          torch::jit::StaticModule smod(mod, false, opts);
          torch::jit::StaticRuntime runtime(smod);

          // Re-enter the runtime across batch sizes to exercise
          // re-planning under each option combination.
          for (int batch_size : {1, 8, 32}) {
            for (int i = 0; i < 2; ++i) {
              auto ad_emb_packed =
                  torch::randn({batch_size, 1, embedding_size});
              auto user_emb = torch::randn({batch_size, 1, embedding_size});
              auto wide = torch::randn({batch_size, num_features});

              // run jit graph executor
              std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
              auto output_1 = getTensor(mod.forward(inputs));

              // run static runtime
              std::vector<c10::IValue> input_tensors(
                  {ad_emb_packed, user_emb, wide});
              auto outputs = runtime(input_tensors, {}).toTupleRef().elements();
              ASSERT_TRUE(outputs.size() > 0);
              auto output_2 = outputs[0].toTensor();
              runtime.check_for_memory_leak();
              EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
              if (manage_output_tensors) {
                // Managed output tensors must be released before the next
                // invocation.
                runtime.deallocateOutputTensors();
                runtime.checkOutputTensorMemoryLeaks();
              }
            }
          }
        }
      }
    }
  }
}
|
|
|
|
TEST(StaticRuntime, ManageOutputTensors) {
  // testStaticRuntime covers the manage_output_tensors configuration; the
  // second call re-enters the runtime with a differently shaped input.
  const std::string test_graph = R"IR(
    graph(%0 : Tensor):
      # With manage_output_tensor enabled, this tensor is managed.
      %1 : Tensor = aten::abs(%0)
      # The output container object is never managed.
      %2 : (Tensor) = prim::TupleConstruct(%1)
      return (%2)
  )IR";
  const auto first_input = at::randn({2, 2});
  const auto second_input = at::randn({3, 6});
  std::vector<at::IValue> first_args{first_input};
  std::vector<at::IValue> second_args{second_input};
  testStaticRuntime(test_graph, first_args);
  testStaticRuntime(test_graph, first_args, second_args);
}
|
|
|
|
TEST(
    StaticRuntime,
    ManageOutputTensorsReturnsOutputContainingManagedOutputTensor) {
  const std::string test_graph = R"IR(
    graph(%0 : Tensor):
      # With manage_output_tensor enabled, this tensor is managed.
      %1 : Tensor = aten::abs(%0)
      # The output container object is never managed.
      %2 : (Tensor) = prim::TupleConstruct(%1)
      return (%2)
  )IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(test_graph, g.get());
  torch::jit::StaticModuleOptions opts{
      /*cleanup_activations=*/true,
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  auto a = at::randn({2, 2});
  std::vector<at::IValue> args{a};
  torch::jit::StaticModule smod(g, opts);
  torch::jit::StaticRuntime runtime(smod);
  // Profile run.
  {
    IValue tuple = runtime(args, {});
    ASSERT_TRUE(tuple.isTuple());
    ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
    // Do not manage input value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
    // Do not manage direct output value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
    IValue element = tuple.toTupleRef().elements()[0];
    // Tensor to be managed, but not yet during the profile run.
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Second run that manages output tensors.
  {
    IValue tuple = runtime(args, {});
    ASSERT_TRUE(tuple.isTuple());
    ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
    // Do not manage input value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
    // Do not manage direct output value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
    IValue element = tuple.toTupleRef().elements()[0];
    // After the profile run, the contained output tensor is managed.
    EXPECT_TRUE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
}
|
|
|
|
TEST(StaticRuntime, ManageOutputTensorsWithDeallocateOutputTensors) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();

  torch::jit::StaticModuleOptions opts{
      /*cleanup_activations=*/true,
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  torch::jit::StaticModule smod(mod, false, opts);
  torch::jit::StaticRuntime runtime(smod);
  // Reenter the runtime with the input with the same shape/different shapes.
  for (const int batch_size : {8, 8, 24, 8}) {
    const auto ad_emb = torch::randn({batch_size, 1, embedding_size});
    const auto user = torch::randn({batch_size, 1, embedding_size});
    const auto wide_features = torch::randn({batch_size, num_features});
    std::vector<c10::IValue> inputs({ad_emb, user, wide_features});
    runtime(inputs, {});
    runtime.check_for_memory_leak();
    // Deallocating between runs keeps the managed-output arena clean.
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
}
|
|
|
|
// Calling the runtime again while managed output tensors are still alive
// must fail loudly rather than silently reuse their memory.
TEST(StaticRuntime, ManageOutputTensorsWithoutDeallocateOutputTensors) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();

  torch::jit::StaticModuleOptions opts{
      /*cleanup_activations=*/true,
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  torch::jit::StaticModule smod(mod, false, opts);
  torch::jit::StaticRuntime runtime(smod);
  int batch_size = 8;
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});
  std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
  // Profile run.
  runtime(input_tensors, {});
  runtime.deallocateOutputTensors();
  // Run again to allocate output Tensors without deallocating them.
  runtime(input_tensors, {});
  // Memory leak checking fails.
  EXPECT_THROW(runtime.checkOutputTensorMemoryLeaks(), std::exception);
  // Calling the runtime without deallocation fails too.
  EXPECT_THROW(runtime(input_tensors, {}), std::exception);
  // After deallocation, everything works fine.
  runtime.deallocateOutputTensors();
  runtime.checkOutputTensorMemoryLeaks();
  runtime(input_tensors, {});
}
|
|
|
|
// Verifies that disableManageOutputTensors() takes effect: after the call,
// returned tensors are no longer backed by the managed arena and survive
// deallocateOutputTensors().
TEST(StaticRuntime, DisableManageOutputTensors) {
  const std::string test_graph = R"IR(
    graph(%0 : Tensor):
      # With manage_output_tensor enabled, this tensor is managed.
      %1 : Tensor = aten::abs(%0)
      # The output container object is never managed.
      %2 : (Tensor) = prim::TupleConstruct(%1)
      return (%2)
  )IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(test_graph, g.get());
  torch::jit::StaticModuleOptions opts{
      /*cleanup_activations=*/true,
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  auto a = at::randn({2, 2});
  std::vector<at::IValue> args{a};
  torch::jit::StaticModule smod(g, opts);
  torch::jit::StaticRuntime runtime(smod);
  // Profile run: output not yet managed.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Second run that manages output tensors.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_TRUE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }

  // Reset the runtime and start profiling again.
  runtime.disableManageOutputTensors();

  IValue copied_output_tensor;
  IValue original_output_tensor;
  // New profile run.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    copied_output_tensor = element.deepcopy();
    original_output_tensor = element;
    tuple = IValue();
    // No-op since manage_output_tensor is disabled now.
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Ensure that `original_output_tensor` is no longer managed: even after
  // calling `runtime.deallocateOutputTensors();` `original_output_tensor` still
  // contains a valid value.
  EXPECT_TRUE(
      original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));

  // Ensure that the second optimized run does not manage the output tensor
  // either.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    copied_output_tensor = element.deepcopy();
    original_output_tensor = element;
    tuple = IValue();
    // No-op since manage_output_tensor is disabled now.
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Ensure that `original_output_tensor` is no longer managed: even after
  // calling `runtime.deallocateOutputTensors();` `original_output_tensor` still
  // contains a valid value.
  EXPECT_TRUE(
      original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));
}
|
|
|
|
// fuseStaticSubgraphs should carve at least one prim::StaticSubgraph node
// out of the deep&wide forward graph, and the fused module must still
// produce the same result as the unfused one.
TEST(StaticRuntime, FusionPass) {
  const int embedding_size = 32;
  const int num_features = 50;
  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      torch::jit::Module module = getDeepAndWideSciptModel();
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(module.forward(inputs));

      // Fuse in place, then confirm a StaticSubgraph node appeared.
      Method method = module.get_method("forward");
      auto graph = method.graph();
      fuseStaticSubgraphs(graph, 2);
      bool hit = false;
      for (const auto& n : module.get_method("forward").graph()->nodes()) {
        if (n->kind() == torch::jit::prim::StaticSubgraph) {
          hit = true;
        }
      }
      EXPECT_TRUE(hit);
      // The fused module must produce the same output.
      auto output_2 = getTensor(module.forward(inputs));
      EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
    }
  }
}
|
|
|
|
// Builds a ProcessedNodeInputs holding the given input indices, in order.
static ProcessedNodeInputs createProcessedNodeInputs(
    c10::ArrayRef<uint16_t> inputs) {
  const size_t num_inputs = inputs.size();
  ProcessedNodeInputs result(num_inputs);
  for (size_t idx = 0; idx < num_inputs; ++idx) {
    result[idx] = inputs[idx];
  }
  return result;
}
|
|
|
|
// A node whose inputs are immutable (no out= variant in the script) must
// report overlap when its output is made to alias an input.
TEST(
    ProcessedNode,
    VerifyNoMemoryOverlapWithImmutableInputsWithImmutableArguments) {
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        return (b)
  )JIT";
  script::Module module("module");
  // Not using out= variant.
  module.define(sigmoid_script);
  torch::jit::StaticModule smodule(module);
  Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
  // values[0] is the node's input slot, values[1] its output slot.
  std::array<IValue, 2> values = {torch::randn({2, 3}), torch::randn({3, 1})};
  ProcessedFunction fn(
      sigmoid_node,
      /*enable_out_variant=*/true,
      /*check_memory_overlap=*/false);
  ProcessedNode pnode(sigmoid_node, &fn, createProcessedNodeInputs({0}), 1);
  pnode.set_values(values.data());
  EXPECT_TRUE(pnode.verify_no_memory_overlap());

  // Aliasing the output to the (immutable) input must be flagged.
  pnode.Output(0) = values[0];
  EXPECT_FALSE(pnode.verify_no_memory_overlap());
}
|
|
|
|
// A node whose script uses an out= variant has mutable arguments, so an
// output aliasing an input is allowed and must NOT be reported as overlap.
TEST(
    ProcessedNode,
    VerifyNoMemoryOverlapWithImmutableInputsWithMutableArguments) {
  script::Module module("module");
  // Using out= variant.
  module.define(sigmoid_inplace_script);
  torch::jit::StaticModule smodule(module);
  Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
  // values[0] is the node's input slot, values[1] its output slot.
  std::array<IValue, 2> values = {torch::randn({2, 3}), torch::randn({3, 1})};
  ProcessedFunction fn(
      sigmoid_node,
      /*enable_out_variant=*/true,
      /*check_memory_overlap=*/false);
  ProcessedNode pnode(sigmoid_node, &fn, createProcessedNodeInputs({0}), 1);
  pnode.set_values(values.data());

  // Sanity-check that Output(0) refers to values[1].
  ASSERT_EQ(&pnode.Output(0), &values[1]);
  EXPECT_TRUE(pnode.verify_no_memory_overlap());

  // Overlap with a mutable input is permitted.
  pnode.Output(0) = values[0];
  EXPECT_TRUE(pnode.verify_no_memory_overlap());
}
|
|
|
|
// Checks that verify_no_memory_overlap() detects two outputs of the same node
// sharing storage. prim::ListUnpack produces two tensor outputs from one
// input, making it a convenient multi-output node for this test.
TEST(ProcessedNode, VerifyNoMemoryOverlapWithOverlappingOutputs) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(
      R"IR(
    graph(%0):
      %1 : Tensor, %2 : Tensor = prim::ListUnpack(%0)
      return (%1, %2))IR",
      g.get());
  torch::jit::StaticModule smodule(g);
  Node* list_unpack_node = getNodeWithKind(smodule, "prim::ListUnpack");
  {
    // Case 1: the two outputs are distinct tensors -> no overlap.
    std::array<IValue, 3> values = {
        at::randn({2, 3}), at::empty({1, 3}), at::empty({4, 5})};
    ProcessedFunction fn(
        list_unpack_node,
        /*enable_out_variant=*/true,
        /*check_memory_overlap=*/false);
    ProcessedNode list_unpack_pnode(
        list_unpack_node, &fn, createProcessedNodeInputs({0}), 1);
    list_unpack_pnode.set_values(values.data());
    ASSERT_EQ(list_unpack_pnode.outputs().size(), 2);
    EXPECT_TRUE(list_unpack_pnode.verify_no_memory_overlap());
  }
  {
    // Case 2: both outputs are set to the same tensor -> overlap detected.
    std::array<IValue, 3> values = {
        at::randn({2, 3}), at::empty({1, 3}), at::empty({4, 5})};
    ProcessedFunction fn(
        list_unpack_node,
        /*enable_out_variant=*/true,
        /*check_memory_overlap=*/false);
    ProcessedNode list_unpack_pnode(
        list_unpack_node, &fn, createProcessedNodeInputs({0}), 1);
    list_unpack_pnode.set_values(values.data());
    auto b = at::randn({2, 3});
    list_unpack_pnode.Output(0) = b;
    list_unpack_pnode.Output(1) = b;
    EXPECT_FALSE(list_unpack_pnode.verify_no_memory_overlap());
  }
}
|
|
|
|
namespace test {
|
|
at::Tensor bad_add(const at::Tensor& self, int64_t b) {
|
|
if (b == 0) {
|
|
return self;
|
|
}
|
|
return at::native::add(self, b);
|
|
}
|
|
|
|
at::Tensor good_add(const at::Tensor& self, int64_t b) {
|
|
if (b == 0) {
|
|
return self;
|
|
}
|
|
return at::native::add(self, b);
|
|
}
|
|
} // namespace test
|
|
|
|
// test::bad_add has the schema with incorrect alias annotation.
// test::good_add has the correct alias annotation.
TORCH_LIBRARY_FRAGMENT(test, m) {
  // bad_add's schema claims a non-aliasing output, but the implementation
  // returns `self` (an alias) when b == 0.
  m.def("bad_add(Tensor self, int b=0) -> Tensor");
  // good_add's Tensor(a) annotations correctly declare that the output may
  // alias the `self` argument.
  m.def("good_add(Tensor(a) self, int b=0) -> Tensor(a)");
}
|
|
// CPU kernels for the test ops declared in the TORCH_LIBRARY_FRAGMENT above.
TORCH_LIBRARY_IMPL(test, CPU, m) {
  m.impl("bad_add", ::test::bad_add);
  m.impl("good_add", ::test::good_add);
}
|
|
|
|
// Runs a graph through Static Runtime where test::bad_add's schema wrongly
// claims its output never aliases its input (it does when %s == 0), checking
// that the runtime still produces correct results for the covered cases.
TEST(StaticRuntime, BadSchemaAliasInfo) {
  const std::string src = R"IR(
    graph(%x: Tensor, %s: int):
        %c0 : int = prim::Constant[value=0]()
        %c1 : int = prim::Constant[value=1]()
        %a = aten::add(%x, %x, %c1)
        %b1 = test::bad_add(%a, %s) # b1 aliases a
        %t : (Tensor) = prim::TupleConstruct(%b1)
        return (%t)
  )IR";

  const auto x1 = at::randn({2, 2});
  // big enough to trigger resize of the internal buffer
  const auto x2 = at::randn({3, 6});
  testStaticRuntime(src, {x1, 0}, {x2, 10});
  // This test doesn't pass yet. This is the corner case mentioned in Step 2 of
  // [Check and correct bad schema alias info at runtime]
  // testStaticRuntime(src, {x1, 10}, {x2, 0});
}
|
|
|
|
// This test repeats the last test, but with the correct schema alias
// annotations
TEST(StaticRuntime, GoodSchemaAliasInfo) {
  // comment out the prim::TupleConstruct repro the failure of
  // DCHECK(!isManagedOutputTensor(*outputs_[0]));
  const std::string src = R"IR(
    graph(%x: Tensor, %s: int):
        %c0 : int = prim::Constant[value=0]()
        %c1 : int = prim::Constant[value=1]()
        %a = aten::add(%x, %x, %c1)
        %b1 = test::good_add(%a, %s) # b1 aliases a
        # return (%b1)
        %t : (Tensor) = prim::TupleConstruct(%b1)
        return (%t)
  )IR";

  const auto x1 = at::randn({2, 2});
  // big enough to trigger resize of the internal buffer
  const auto x2 = at::randn({3, 6});
  testStaticRuntime(src, {x1, 0}, {x2, 10});
  // With correct Tensor(a) annotations, the aliasing case (%s == 0 on the
  // second run) also passes, unlike in BadSchemaAliasInfo above.
  testStaticRuntime(src, {x1, 10}, {x2, 0});
}
|
|
|
|
// Checks that ProcessedFunction resolves each node to the expected execution
// kind: sigmoid to its out variant, transpose to a native function.
TEST(ProcessedFunction, ProcessedFunction) {
  const auto script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        c = torch.transpose(b, 0, 1)
        return (c)
  )JIT";
  script::Module mod("module");
  mod.define(script);
  torch::jit::StaticModule static_mod(mod);

  {
    // sigmoid resolves to an out-variant implementation.
    Node* node = getNodeWithKind(static_mod, "aten::sigmoid");
    ProcessedFunction fn(
        node,
        /*enable_out_variant=*/true,
        /*check_memory_overlap=*/false);
    EXPECT_TRUE(fn.f());
    EXPECT_EQ(fn.kind(), ProcessedFunction::Kind::kOutVariant);
    EXPECT_FALSE(fn.checkMemoryOverlap());
  }

  {
    // transpose resolves to a native function, not an out variant.
    Node* node = getNodeWithKind(static_mod, "aten::transpose");
    ProcessedFunction fn(
        node,
        /*enable_out_variant=*/true,
        /*check_memory_overlap=*/false);
    EXPECT_TRUE(fn.f());
    EXPECT_EQ(fn.kind(), ProcessedFunction::Kind::kNativeFunction);
    EXPECT_FALSE(fn.checkMemoryOverlap());
  }
}
|