[JIT] Exit Transform Rewrite (#38282)
Summary: After an early return, we conditionalize all further execution. This means that currently the pattern `if return elif return elif return` generates better code than `if return if return if return`. It is obviously not good for semantically equivalent code to generate worse IR, so we rewrite the graph to handle this case. This came up in https://github.com/pytorch/pytorch/pull/37171

```
@torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool, y: bool) -> int:
    _0 = uninitialized(int)
    if x:
        _1, _2 = True, 1
    else:
        _1, _2 = False, _0
    if _1:
        _3 = _2
    else:
        _3 = 2
    return _3
```
while
```
@torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    else:
        return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool, y: bool) -> int:
    if x:
        _0 = 1
    else:
        _0 = 2
    return _0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/38282
Differential Revision: D21576733
Pulled By: eellison
fbshipit-source-id: 80cf1ad7fbda6d8d58557abbfb21c90eafae7488
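For reference, a minimal sketch of what the first example should produce once this rewrite is in place (the printed code in the comment is illustrative, inferred from the if/else case above, not captured output):

```
import torch

@torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    return 2

# With consecutive-if inlining, this is expected to print roughly:
# def test_foo(x: bool, y: bool) -> int:
#     if x:
#         _0 = 1
#     else:
#         _0 = 2
#     return _0
print(test_foo.code)
```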
parent 62afc2d63d
commit daa85cfe2e
5 changed files with 152 additions and 21 deletions
@@ -15461,7 +15461,7 @@ a")
            output = torch.tanh(self)
            def backward(grad_output):
                a = 1
                if True:
                    if output:
                        return 1
                    else:
                        a = 2
@ -15614,8 +15614,8 @@ a")
|
|||
self.checkScript(test_loop_no_escape, (-1,))
|
||||
self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "")
|
||||
|
||||
# one if added to guard x + 3, the throw in loop does not escape
|
||||
test_num_ifs(test_loop_no_escape, 2)
|
||||
# if guard gets optimized away
|
||||
test_num_ifs(test_loop_no_escape, 1)
|
||||
|
||||
def test_loop_exception_with_continue(x):
|
||||
# type: (int)
|
||||
|
|
@ -15659,8 +15659,8 @@ a")
|
|||
func = torch.jit.CompilationUnit(code).test_exit_pair_reset
|
||||
self.assertEqual(func(1,), 2)
|
||||
self.assertEqual(func(-1,), -1)
|
||||
FileCheck().check_count("prim::If", 2, exactly=True).check("aten::add")\
|
||||
.run(func.graph) # if added to guard a + 1
|
||||
# final a + 1 gets inlined into the first branch and optimized away
|
||||
FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph)
|
||||
|
||||
def test_non_final_return(self):
|
||||
def simple(x):
|
||||
|
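The `FileCheck().check_count(...)` idiom these tests rely on counts IR nodes in a scripted graph; a minimal self-contained sketch of the pattern (the function `guarded_add` is illustrative, not from the test suite):

```
import torch
from torch.testing import FileCheck

@torch.jit.script
def guarded_add(x: int):
    if x > 0:
        return x + 1
    return x - 1

# With the early-return rewrite, exactly one prim::If should remain.
FileCheck().check_count("prim::If", 1, exactly=True).run(guarded_add.graph)
```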
@@ -15756,21 +15756,6 @@ a")
        for i in range(4):
            self.checkScript(complicated, (i,))

-   def test_partial_returns_shape_prop(self):
-       @torch.jit.script
-       def test_shape_prop(x):
-           # type: (int) -> int
-           if not bool(x):
-               return x
-           else:
-               z = torch.zeros([2, 2], dtype=torch.int64)
-               return int(z[0])
-
-       test_shape_prop(torch.tensor(0.5))
-       graph = test_shape_prop.graph_for(torch.tensor(0.5))
-       # Shape analysis of z should propagate through if statement
-       FileCheck().check("Long(2:2, 2:1)").check("prim::If").run(graph)
-
    def test_partial_returns(self):
        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
            @torch.jit.script
@ -15881,6 +15866,45 @@ a")
|
|||
|
||||
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
|
||||
|
||||
def test_early_return_rewrite(self):
|
||||
def test_foo(x: bool):
|
||||
if x:
|
||||
return 1
|
||||
return 2
|
||||
|
||||
self.checkScript(test_foo, (True,))
|
||||
self.checkScript(test_foo, (False,))
|
||||
FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph)
|
||||
|
||||
def test_multiple(x: int):
|
||||
if x == 5:
|
||||
return x * x
|
||||
else:
|
||||
y = 2 * x
|
||||
|
||||
z = y * 2
|
||||
if z == 8:
|
||||
return 1
|
||||
|
||||
if z != 16:
|
||||
z = z - 2
|
||||
abc = 4
|
||||
else:
|
||||
return 3
|
||||
|
||||
z = z * abc
|
||||
return z * z * z
|
||||
|
||||
self.checkScript(test_multiple, (5,))
|
||||
self.checkScript(test_multiple, (2,))
|
||||
self.checkScript(test_multiple, (4,))
|
||||
self.checkScript(test_multiple, (3,))
|
||||
self.checkScript(test_multiple, (10,))
|
||||
|
||||
graph = torch.jit.script(test_multiple).graph
|
||||
FileCheck().check_count("prim::If", 3, exactly=True).run(graph)
|
||||
print(torch.jit.script(test_multiple).code)
|
||||
|
||||
def test_is_scripting_metacompile(self):
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
|
|
|
|||
|
|
@@ -473,6 +473,7 @@ class TestFuser(JitTestCase):

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
+   @torch.jit._disable_emit_hooks_decorator()
    @_inline_everything
    def test_fuse_decompose_normalization(self):
        class ResLike(torch.jit.ScriptModule):
@@ -500,6 +500,7 @@ class TestFuser(JitTestCase):

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
+   @torch.jit._disable_emit_hooks_decorator()
    @_inline_everything
    def test_fuse_decompose_normalization(self):
        class ResLike(torch.jit.ScriptModule):
@@ -505,6 +505,101 @@ struct ExitTransformer {
  std::shared_ptr<Graph> graph_;
};

+bool inlineConsecutiveIfs(Node* node) {
+  if (node->kind() != prim::If || node->next()->kind() != prim::If) {
+    return false;
+  }
+
+  IfView first_if(node);
+  IfView second_if(node->next());
+
+  // the second if must depend on a value output by the first if for us to
+  // inline the second if
+  if (second_if.cond()->node() != node) {
+    return false;
+  }
+
+  // both blocks must output a constant value for us to inline, and those values
+  // must be different. if the values are the same, then the subsequent if node
+  // will get constant prop'd away, and inlining it into the first node would
+  // double code size
+
+  auto input_offset = second_if.cond()->offset();
+  auto maybe_then_value = toIValue(first_if.thenOutputs().at(input_offset));
+  auto maybe_else_value = toIValue(first_if.elseOutputs().at(input_offset));
+  if (!maybe_then_value || !maybe_else_value ||
+      maybe_then_value->toBool() == maybe_else_value->toBool()) {
+    return false;
+  }
+
+  bool then_value = maybe_then_value->toBool();
+  bool else_value = maybe_else_value->toBool();
+
+  for (auto i = 0; i < 2; ++i) {
+    Block* first_if_block;
+    Block* second_if_block;
+
+    if (i == 0) {
+      first_if_block = first_if.thenBlock();
+      second_if_block =
+          then_value ? second_if.thenBlock() : second_if.elseBlock();
+    } else {
+      first_if_block = first_if.elseBlock();
+      second_if_block =
+          else_value ? second_if.thenBlock() : second_if.elseBlock();
+    }
+
+    // we need to replace values that were used in the second if that were
+    // outputs of the first if with the equivalent value in the scope of the
+    // block we're copying into
+    auto value_map = [&](Value* v) {
+      if (v->node() != first_if.node()) {
+        return v;
+      }
+      auto offset = v->offset();
+      return first_if_block->outputs().at(offset);
+    };
+
+    // cloneFrom also copies block outputs from second_if_block onto
+    // first_if_block
+    first_if_block->cloneFrom(second_if_block, value_map);
+  }
+
+  for (Value* output : second_if.outputs()) {
+    auto new_out = first_if.node()->addOutput()->copyMetadata(output);
+    output->replaceAllUsesWith(new_out);
+  }
+  second_if.node()->destroy();
+  return true;
+}
+
+// After an early return, we conditionalize all further execution.
+// This means code like the following:
+// if x:
+//     return 1
+// return 2
+// gets generated as one if statement checking `if x`, and then a second if
+// statement that conditionalizes execution. We can rewrite cases like these
+// into one if statement, so that the above example gets rewritten to look
+// like:
+// if x:
+//     return 1
+// else:
+//     return 2
+void inlineConsecutiveIfs(Block* block) {
+  for (auto it = block->nodes().begin(), end = block->nodes().end();
+       it != end;) {
+    for (Block* b : it->blocks()) {
+      inlineConsecutiveIfs(b);
+    }
+
+    // if we fused two ifs, we need to re-check the current node against the
+    // new next node, so only advance when no fusion happened
+    if (!inlineConsecutiveIfs(*it)) {
+      it++;
+    }
+  }
+}
+
 // This pass takes in a graph where LoopContinuation & ReturnStmts exist in the
 // graph and erases them in the graph, correctly setting block outputs.
 // prim::LoopContinuation(*vals) means that the values are targeting the most
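In Python-like pseudocode, the precondition that `inlineConsecutiveIfs` checks before fusing two ifs looks as follows. This is a hedged sketch over toy types; `ToyIf` and `can_inline` are hypothetical illustrations, not PyTorch APIs:

```
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyIf:
    cond_producer: Optional["ToyIf"]  # node producing this if's condition
    then_out: Optional[bool] = None   # constant bool output of then-block, if known
    else_out: Optional[bool] = None   # constant bool output of else-block, if known

def can_inline(first: ToyIf, second: ToyIf) -> bool:
    # the second if's condition must be an output of the first if
    if second.cond_producer is not first:
        return False
    # both branch outputs must be statically known constants ...
    if first.then_out is None or first.else_out is None:
        return False
    # ... and must differ; equal constants would be folded away by constant
    # propagation, and inlining would only duplicate code
    return first.then_out != first.else_out
```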
@ -586,6 +681,7 @@ void TransformExits(std::shared_ptr<Graph>& graph) {
|
|||
e_loop.transformLoopContinuations();
|
||||
ExitTransformer e_ret(graph);
|
||||
e_ret.transformReturnStmts();
|
||||
inlineConsecutiveIfs(graph->block());
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@@ -13,6 +13,7 @@ from torch.nn import Module
from torch.serialization import validate_cuda_device
from torch._six import PY37, with_metaclass, string_classes, get_function_from_type
from torch.utils import set_module
+from torch.autograd.grad_mode import _DecoratorContextManager

import collections
import contextlib
@@ -1098,7 +1099,6 @@ def _try_get_overloaded_fn(mod, field):
class ScriptWarning(Warning):
    pass

-
@contextlib.contextmanager
def _disable_emit_hooks():
    hooks = torch._C._jit_get_emit_hooks()
@@ -1107,6 +1107,15 @@ def _disable_emit_hooks():
    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])


+class _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
+    def __enter__(self):
+        self.hooks = torch._C._jit_get_emit_hooks()
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
+
+
# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.
def _is_new_style_class(cls):
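For context, `_DecoratorContextManager` (imported above from `torch.autograd.grad_mode`) lets a context manager double as a decorator: its `__call__` wraps the decorated function so each invocation runs inside the manager. A hedged toy sketch of the pattern (`_noop_guard` is a hypothetical example class, not part of PyTorch):

```
from torch.autograd.grad_mode import _DecoratorContextManager

class _noop_guard(_DecoratorContextManager):  # hypothetical example class
    def __enter__(self):
        pass  # save and mutate some global state here

    def __exit__(self, *args):
        pass  # restore it here

@_noop_guard()  # an instance acts as a decorator via __call__
def f():
    return 1
```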