mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
fix inliner bug (#25052)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25052 Previously we would not inline nested functions, now we do. Test Plan: Imported from OSS Differential Revision: D16973848 Pulled By: suo fbshipit-source-id: 94aa0b6f84a2577a663f4e219f930d2c6396d585
This commit is contained in:
parent
18c77dd243
commit
fa902c58ee
5 changed files with 87 additions and 11 deletions
70
test/cpp/jit/test_inliner.cpp
Normal file
70
test/cpp/jit/test_inliner.cpp
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
#include <torch/csrc/jit/script/compilation_unit.h>
|
||||
#include <torch/csrc/jit/script/module.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
|
||||
const auto testSource = R"JIT(
|
||||
def foo1(x):
|
||||
print("one")
|
||||
return x
|
||||
|
||||
def foo2(x):
|
||||
print("two")
|
||||
return foo1(x)
|
||||
|
||||
def foo3(x):
|
||||
print("three")
|
||||
return foo2(x)
|
||||
)JIT";
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
using namespace script;
|
||||
using namespace testing;
|
||||
|
||||
struct InlinerGuard {
|
||||
explicit InlinerGuard(bool shouldInline)
|
||||
: oldState_(getInlineEverythingMode()) {
|
||||
getInlineEverythingMode() = shouldInline;
|
||||
}
|
||||
|
||||
~InlinerGuard() {
|
||||
getInlineEverythingMode() = oldState_;
|
||||
}
|
||||
|
||||
bool oldState_;
|
||||
};
|
||||
|
||||
void testInliner() {
|
||||
{
|
||||
// Test that the recursive inlining works
|
||||
// disable automatic inlining so we can test it manually
|
||||
InlinerGuard guard(/*shouldInline=*/false);
|
||||
|
||||
CompilationUnit cu(testSource);
|
||||
auto& fn = cu.get_function("foo3");
|
||||
|
||||
auto g = fn.graph();
|
||||
Inline(*g, /*recurse=*/true);
|
||||
FileCheck().check_count("prim::Print", 3)->run(*g);
|
||||
}
|
||||
{
|
||||
// disable automatic inlining so we can test it manually
|
||||
InlinerGuard guard(/*shouldInline=*/false);
|
||||
|
||||
CompilationUnit cu(testSource);
|
||||
auto& fn = cu.get_function("foo3");
|
||||
|
||||
auto g = fn.graph();
|
||||
Inline(*g, /*recurse=*/false);
|
||||
FileCheck()
|
||||
.check("three")
|
||||
->check("two")
|
||||
->check_count("prim::CallFunction", 1)
|
||||
->run(*g);
|
||||
}
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -62,7 +62,8 @@ namespace jit {
|
|||
_(DCE) \
|
||||
_(CustomFusionNestedBlocks) \
|
||||
_(ImportTooNew) \
|
||||
_(ClassDerive)
|
||||
_(ClassDerive) \
|
||||
_(Inliner)
|
||||
|
||||
#define TH_FORALL_TESTS_CUDA(_) \
|
||||
_(ArgumentSpec) \
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ namespace prim {
|
|||
using namespace ::c10::prim;
|
||||
}
|
||||
|
||||
void inlineCalls(Block* block) {
|
||||
void inlineCalls(Block* block, bool recurse) {
|
||||
for (auto it = block->nodes().begin(), end = block->nodes().end();
|
||||
it != end;) {
|
||||
Node* cur = *it++;
|
||||
|
|
@ -20,29 +20,32 @@ void inlineCalls(Block* block) {
|
|||
auto fun_type =
|
||||
function_constant->output()->type()->expect<FunctionType>();
|
||||
cur->removeInput(0);
|
||||
inlineCallTo(cur, *fun_type->function()->graph());
|
||||
if (!function_constant->hasUses()) {
|
||||
function_constant->destroy();
|
||||
if (recurse) {
|
||||
Inline(*fun_type->function()->graph(), recurse);
|
||||
}
|
||||
inlineCallTo(cur, *fun_type->function()->graph());
|
||||
} break;
|
||||
case prim::CallMethod: {
|
||||
const std::string& name = cur->s(attr::name);
|
||||
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
|
||||
auto function = class_type->getMethod(name);
|
||||
if (recurse) {
|
||||
Inline(*function->graph(), recurse);
|
||||
}
|
||||
inlineCallTo(cur, *function->graph());
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
for (auto b : cur->blocks()) {
|
||||
inlineCalls(b);
|
||||
inlineCalls(b, recurse);
|
||||
}
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Inline(Graph& graph) {
|
||||
inlineCalls(graph.block());
|
||||
void Inline(Graph& graph, bool recurse) {
|
||||
inlineCalls(graph.block(), recurse);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
TORCH_API void Inline(Graph& graph);
|
||||
// Inline function and method calls. If `recurse` is true, inline all nested
|
||||
// calls as well, resulting in a completely flattened graph.
|
||||
TORCH_API void Inline(Graph& graph, bool recurse = false);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -86,8 +86,8 @@ def _split_tensor_list_constants(g, block):
|
|||
|
||||
|
||||
def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False):
|
||||
# Inline everyting
|
||||
torch._C._jit_pass_inline(graph)
|
||||
# Inline everyting (recursively)
|
||||
torch._C._jit_pass_inline(graph, True)
|
||||
|
||||
# Remove fork/wait nodes
|
||||
torch._C._jit_pass_inline_fork_wait(graph)
|
||||
|
|
|
|||
Loading…
Reference in a new issue