fix inliner bug (#25052)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25052

Previously we would not inline nested functions, now we do.

Test Plan: Imported from OSS

Differential Revision: D16973848

Pulled By: suo

fbshipit-source-id: 94aa0b6f84a2577a663f4e219f930d2c6396d585
This commit is contained in:
Michael Suo 2019-08-28 19:44:31 -07:00 committed by Facebook Github Bot
parent 18c77dd243
commit fa902c58ee
5 changed files with 87 additions and 11 deletions

View file

@ -0,0 +1,70 @@
#include <test/cpp/jit/test_base.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/testing/file_check.h>
const auto testSource = R"JIT(
def foo1(x):
print("one")
return x
def foo2(x):
print("two")
return foo1(x)
def foo3(x):
print("three")
return foo2(x)
)JIT";
namespace torch {
namespace jit {
using namespace script;
using namespace testing;
struct InlinerGuard {
explicit InlinerGuard(bool shouldInline)
: oldState_(getInlineEverythingMode()) {
getInlineEverythingMode() = shouldInline;
}
~InlinerGuard() {
getInlineEverythingMode() = oldState_;
}
bool oldState_;
};
void testInliner() {
{
// Test that the recursive inlining works
// disable automatic inlining so we can test it manually
InlinerGuard guard(/*shouldInline=*/false);
CompilationUnit cu(testSource);
auto& fn = cu.get_function("foo3");
auto g = fn.graph();
Inline(*g, /*recurse=*/true);
FileCheck().check_count("prim::Print", 3)->run(*g);
}
{
// disable automatic inlining so we can test it manually
InlinerGuard guard(/*shouldInline=*/false);
CompilationUnit cu(testSource);
auto& fn = cu.get_function("foo3");
auto g = fn.graph();
Inline(*g, /*recurse=*/false);
FileCheck()
.check("three")
->check("two")
->check_count("prim::CallFunction", 1)
->run(*g);
}
}
} // namespace jit
} // namespace torch

View file

@ -62,7 +62,8 @@ namespace jit {
_(DCE) \
_(CustomFusionNestedBlocks) \
_(ImportTooNew) \
_(ClassDerive)
_(ClassDerive) \
_(Inliner)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \

View file

@ -9,7 +9,7 @@ namespace prim {
using namespace ::c10::prim;
}
void inlineCalls(Block* block) {
void inlineCalls(Block* block, bool recurse) {
for (auto it = block->nodes().begin(), end = block->nodes().end();
it != end;) {
Node* cur = *it++;
@ -20,29 +20,32 @@ void inlineCalls(Block* block) {
auto fun_type =
function_constant->output()->type()->expect<FunctionType>();
cur->removeInput(0);
inlineCallTo(cur, *fun_type->function()->graph());
if (!function_constant->hasUses()) {
function_constant->destroy();
if (recurse) {
Inline(*fun_type->function()->graph(), recurse);
}
inlineCallTo(cur, *fun_type->function()->graph());
} break;
case prim::CallMethod: {
const std::string& name = cur->s(attr::name);
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
auto function = class_type->getMethod(name);
if (recurse) {
Inline(*function->graph(), recurse);
}
inlineCallTo(cur, *function->graph());
}
} break;
default: {
for (auto b : cur->blocks()) {
inlineCalls(b);
inlineCalls(b, recurse);
}
} break;
}
}
}
void Inline(Graph& graph) {
inlineCalls(graph.block());
void Inline(Graph& graph, bool recurse) {
inlineCalls(graph.block(), recurse);
}
} // namespace jit

View file

@ -5,7 +5,9 @@
namespace torch {
namespace jit {
TORCH_API void Inline(Graph& graph);
// Inline function and method calls. If `recurse` is true, inline all nested
// calls as well, resulting in a completely flattened graph.
TORCH_API void Inline(Graph& graph, bool recurse = false);
} // namespace jit
} // namespace torch

View file

@ -86,8 +86,8 @@ def _split_tensor_list_constants(g, block):
def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False):
# Inline everyting
torch._C._jit_pass_inline(graph)
# Inline everyting (recursively)
torch._C._jit_pass_inline(graph, True)
# Remove fork/wait nodes
torch._C._jit_pass_inline_fork_wait(graph)