[JIT] Exit Transform Rewrite (#38282)

Summary:
After an early return, we conditionalize all further execution. This means that currently the pattern `if return elif return elif return` generates better code than `if return if return if return`. It is clearly undesirable for semantically equivalent code to generate worse IR, so we rewrite the graph to handle this case. This came up in https://github.com/pytorch/pytorch/pull/37171.

```
@torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool,
    y: bool) -> int:
  _0 = uninitialized(int)
  if x:
    _1, _2 = True, 1
  else:
    _1, _2 = False, _0
  if _1:
    _3 = _2
  else:
    _3 = 2
  return _3
```
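Here `_1` is a sentinel boolean recording whether the early return was taken, and the second `prim::If` re-dispatches on it to guard the code that follows the return.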
while
```
@torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    else:
        return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool,
    y: bool) -> int:
  if x:
    _0 = 1
  else:
    _0 = 2
  return _0
```
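With this change both spellings should lower to a single conditional. A minimal way to check this from Python, mirroring the unit test added in this PR (`early_return` is just an illustrative name; assumes a build that includes this pass):
```
import torch
from torch.testing import FileCheck

@torch.jit.script
def early_return(x: bool):
    if x:
        return 1
    return 2

# With the exit transform rewrite, the guard `if` that conditionalizes the
# code after the early return is merged into the first `if`, so only one
# prim::If node should remain in the graph.
FileCheck().check_count("prim::If", 1, exactly=True).run(early_return.graph)
```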
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38282

Differential Revision: D21576733

Pulled By: eellison

fbshipit-source-id: 80cf1ad7fbda6d8d58557abbfb21c90eafae7488

Authored by Elias Ellison on 2020-05-15 12:20:13 -07:00; committed by Facebook GitHub Bot.
Parent: 62afc2d63d
Commit: daa85cfe2e
5 changed files with 152 additions and 21 deletions

```diff
@@ -15461,7 +15461,7 @@ a")
         output = torch.tanh(self)
         def backward(grad_output):
             a = 1
-            if True:
+            if output:
                 return 1
             else:
                 a = 2
@@ -15614,8 +15614,8 @@ a")
         self.checkScript(test_loop_no_escape, (-1,))
         self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "")

-        # one if added to guard x + 3, the throw in loop does not escape
-        test_num_ifs(test_loop_no_escape, 2)
+        # if guard gets optimized away
+        test_num_ifs(test_loop_no_escape, 1)

         def test_loop_exception_with_continue(x):
             # type: (int)
@@ -15659,8 +15659,8 @@ a")
         func = torch.jit.CompilationUnit(code).test_exit_pair_reset
         self.assertEqual(func(1,), 2)
         self.assertEqual(func(-1,), -1)
-        FileCheck().check_count("prim::If", 2, exactly=True).check("aten::add")\
-            .run(func.graph)  # if added to guard a + 1
+        # final a + 1 gets inlined into the first branch and optimized away
+        FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph)

     def test_non_final_return(self):
         def simple(x):
@@ -15756,21 +15756,6 @@ a")
         for i in range(4):
             self.checkScript(complicated, (i,))

-    def test_partial_returns_shape_prop(self):
-        @torch.jit.script
-        def test_shape_prop(x):
-            # type: (int) -> int
-            if not bool(x):
-                return x
-            else:
-                z = torch.zeros([2, 2], dtype=torch.int64)
-                return int(z[0])
-
-        test_shape_prop(torch.tensor(0.5))
-        graph = test_shape_prop.graph_for(torch.tensor(0.5))
-        # Shape analysis of z should propagate through if statement
-        FileCheck().check("Long(2:2, 2:1)").check("prim::If").run(graph)
-
     def test_partial_returns(self):
         with self.assertRaisesRegex(RuntimeError, "does not return along all"):
             @torch.jit.script
@@ -15881,6 +15866,45 @@ a")
         FileCheck().check_not("prim::PythonOp").run(cu.test.graph)

+    def test_early_return_rewrite(self):
+        def test_foo(x: bool):
+            if x:
+                return 1
+            return 2
+
+        self.checkScript(test_foo, (True,))
+        self.checkScript(test_foo, (False,))
+        FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph)
+
+        def test_multiple(x: int):
+            if x == 5:
+                return x * x
+            else:
+                y = 2 * x
+
+            z = y * 2
+            if z == 8:
+                return 1
+
+            if z != 16:
+                z = z - 2
+                abc = 4
+            else:
+                return 3
+
+            z = z * abc
+            return z * z * z
+
+        self.checkScript(test_multiple, (5,))
+        self.checkScript(test_multiple, (2,))
+        self.checkScript(test_multiple, (4,))
+        self.checkScript(test_multiple, (3,))
+        self.checkScript(test_multiple, (10,))
+
+        graph = torch.jit.script(test_multiple).graph
+        FileCheck().check_count("prim::If", 3, exactly=True).run(graph)
+        print(torch.jit.script(test_multiple).code)
+
     def test_is_scripting_metacompile(self):
         @torch.jit.script
         def foo():
```

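As a hand-worked sanity check of the control flow the new `test_multiple` exercises (these values follow from the code above; the test itself only compares scripted against eager execution via `checkScript`): `test_multiple(2)` gives y = 4, z = 8 and returns 1; `test_multiple(4)` gives y = 8, z = 16 and hits the `else: return 3` branch; `test_multiple(3)` gives y = 6, z = 12, so z becomes 10, abc = 4, z becomes 40, and the result is 40 ** 3 = 64000; `test_multiple(5)` returns 25.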
```diff
@@ -473,6 +473,7 @@ class TestFuser(JitTestCase):
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
+    @torch.jit._disable_emit_hooks_decorator
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
```

```diff
@@ -500,6 +500,7 @@ class TestFuser(JitTestCase):
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
+    @torch.jit._disable_emit_hooks_decorator
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
```

```diff
@@ -505,6 +505,101 @@ struct ExitTransformer {
   std::shared_ptr<Graph> graph_;
 };
 
+bool inlineConsecutiveIfs(Node* node) {
+  if (node->kind() != prim::If || node->next()->kind() != prim::If) {
+    return false;
+  }
+
+  IfView first_if(node);
+  IfView second_if(node->next());
+
+  // the second if must depend on a value outputted in the first if for us to
+  // inline the second if
+  if (second_if.cond()->node() != node) {
+    return false;
+  }
+
+  // both blocks must output a constant value for us to inline, and those
+  // values must be different. if the values are the same, then the subsequent
+  // if node will get constant prop'd away, and inlining it into the first
+  // node would double code size
+  auto input_offset = second_if.cond()->offset();
+  auto maybe_then_value = toIValue(first_if.thenOutputs().at(input_offset));
+  auto maybe_else_value = toIValue(first_if.elseOutputs().at(input_offset));
+  if (!maybe_then_value || !maybe_else_value ||
+      maybe_then_value->toBool() == maybe_else_value->toBool()) {
+    return false;
+  }
+
+  bool then_value = maybe_then_value->toBool();
+  bool else_value = maybe_else_value->toBool();
+
+  for (auto i = 0; i < 2; ++i) {
+    Block* first_if_block;
+    Block* second_if_block;
+    if (i == 0) {
+      first_if_block = first_if.thenBlock();
+      second_if_block =
+          then_value ? second_if.thenBlock() : second_if.elseBlock();
+    } else {
+      first_if_block = first_if.elseBlock();
+      second_if_block =
+          else_value ? second_if.thenBlock() : second_if.elseBlock();
+    }
+
+    // we need to replace values that were used in the second if that were
+    // outputs of the first if with the equivalent value in the scope of the
+    // block we're copying into
+    auto value_map = [&](Value* v) {
+      if (v->node() != first_if.node()) {
+        return v;
+      }
+      auto offset = v->offset();
+      return first_if_block->outputs().at(offset);
+    };
+
+    // cloneFrom also copies block outputs from second_if_block onto
+    // first_if_block
+    first_if_block->cloneFrom(second_if_block, value_map);
+  }
+
+  for (Value* output : second_if.outputs()) {
+    auto new_out = first_if.node()->addOutput()->copyMetadata(output);
+    output->replaceAllUsesWith(new_out);
+  }
+  second_if.node()->destroy();
+  return true;
+}
+
+// After an early return, we conditionalize all further execution.
+// This means code like the following:
+//   if x:
+//     return 1
+//   return 2
+// gets generated as one if statement checking `if x`, and then a second if
+// statement that conditionalizes execution. We can rewrite cases like these
+// into one if statement, so that the above example gets rewritten to look
+// like:
+//   if x:
+//     return 1
+//   else:
+//     return 2
+void inlineConsecutiveIfs(Block* block) {
+  for (auto it = block->nodes().begin(), end = block->nodes().end();
+       it != end;) {
+    for (Block* b : it->blocks()) {
+      inlineConsecutiveIfs(b);
+    }
+
+    // if we fused two ifs, we need to check the current node and the new
+    // next node
+    if (!inlineConsecutiveIfs(*it)) {
+      it++;
+    }
+  }
+}
+
 // This pass takes in a graph where LoopContinuation & ReturnStmts exist in the
 // graph and erases them in the graph, correctly setting block outputs.
 // prim::LoopContinuation(*vals) means that the values are targeting the most
@@ -586,6 +681,7 @@ void TransformExits(std::shared_ptr<Graph>& graph) {
   e_loop.transformLoopContinuations();
   ExitTransformer e_ret(graph);
   e_ret.transformReturnStmts();
+  inlineConsecutiveIfs(graph->block());
 }
 } // namespace jit
 } // namespace torch
```

```diff
@@ -13,6 +13,7 @@ from torch.nn import Module
 from torch.serialization import validate_cuda_device
 from torch._six import PY37, with_metaclass, string_classes, get_function_from_type
 from torch.utils import set_module
+from torch.autograd.grad_mode import _DecoratorContextManager

 import collections
 import contextlib
@@ -1098,7 +1099,6 @@ def _try_get_overloaded_fn(mod, field):
 class ScriptWarning(Warning):
     pass

 @contextlib.contextmanager
 def _disable_emit_hooks():
     hooks = torch._C._jit_get_emit_hooks()
@@ -1107,6 +1107,15 @@ def _disable_emit_hooks():
     torch._C._jit_set_emit_hooks(hooks[0], hooks[1])


+def _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
+    def __enter__(self):
+        self.hooks = torch._C._jit_get_emit_hooks()
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
+
+
 # ScriptClasses must be new-style classes because we construct them using their
 # __new__ method.
 def _is_new_style_class(cls):
```