From daa85cfe2edb0a1b6cfda5593fa972b47d4bc646 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Fri, 15 May 2020 12:20:13 -0700
Subject: [PATCH] [JIT] Exit Transform Rewrite (#38282)

Summary:
After an early return, we conditionalize all further execution. This means
that currently the pattern of `if return elif return elif return` generates
better code than `if return if return if return`. It's obviously not good to
have semantically equivalent code generate worse IR, so we should rewrite the
graph to handle this case. This came up in
https://github.com/pytorch/pytorch/pull/37171

```
@torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool, y: bool) -> int:
  _0 = uninitialized(int)
  if x:
    _1, _2 = True, 1
  else:
    _1, _2 = False, _0
  if _1:
    _3 = _2
  else:
    _3 = 2
  return _3
```
while
```
@torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    else:
        return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool, y: bool) -> int:
  if x:
    _0 = 1
  else:
    _0 = 2
  return _0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/38282

Differential Revision: D21576733

Pulled By: eellison

fbshipit-source-id: 80cf1ad7fbda6d8d58557abbfb21c90eafae7488
---
 test/test_jit.py                            | 64 +++++++++----
 test/test_jit_fuser.py                      |  1 +
 test/test_jit_fuser_te.py                   |  1 +
 torch/csrc/jit/frontend/exit_transforms.cpp | 96 +++++++++++++++++++++
 torch/jit/__init__.py                       | 11 ++-
 5 files changed, 152 insertions(+), 21 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 922d7595117..d102b557fdb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15461,7 +15461,7 @@ a")
             output = torch.tanh(self)
             def backward(grad_output):
                 a = 1
-                if True:
+                if output:
                     return 1
                 else:
                     a = 2
@@ -15614,8 +15614,8 @@ a")
         self.checkScript(test_loop_no_escape, (-1,))
         self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "")

-        # one if added to guard x + 3, the throw in loop does not escape
-        test_num_ifs(test_loop_no_escape, 2)
+        # if guard gets optimized away
+        test_num_ifs(test_loop_no_escape, 1)

         def test_loop_exception_with_continue(x):
             # type: (int)
@@ -15659,8 +15659,8 @@ a")
         func = torch.jit.CompilationUnit(code).test_exit_pair_reset
         self.assertEqual(func(1,), 2)
         self.assertEqual(func(-1,), -1)
-        FileCheck().check_count("prim::If", 2, exactly=True).check("aten::add")\
-            .run(func.graph)  # if added to guard a + 1
+        # final a + 1 gets inlined into the first branch and optimized away
+        FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph)

     def test_non_final_return(self):
         def simple(x):
@@ -15756,21 +15756,6 @@ a")
         for i in range(4):
             self.checkScript(complicated, (i,))

-    def test_partial_returns_shape_prop(self):
-        @torch.jit.script
-        def test_shape_prop(x):
-            # type: (int) -> int
-            if not bool(x):
-                return x
-            else:
-                z = torch.zeros([2, 2], dtype=torch.int64)
-                return int(z[0])
-
-        test_shape_prop(torch.tensor(0.5))
-        graph = test_shape_prop.graph_for(torch.tensor(0.5))
-        # Shape analysis of z should propagate through if statement
-        FileCheck().check("Long(2:2, 2:1)").check("prim::If").run(graph)
-
     def test_partial_returns(self):
         with self.assertRaisesRegex(RuntimeError, "does not return along all"):
             @torch.jit.script
@@ -15881,6 +15866,45 @@ a")
         FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
+
+    def test_early_return_rewrite(self):
+        def test_foo(x: bool):
+            if x:
+                return 1
+            return 2
+
+        self.checkScript(test_foo, (True,))
+        self.checkScript(test_foo, (False,))
+        FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph)
+
+        def test_multiple(x: int):
+            if x == 5:
+                return x * x
+            else:
+                y = 2 * x
+
+            z = y * 2
+            if z == 8:
+                return 1
+
+            if z != 16:
+                z = z - 2
+                abc = 4
+            else:
+                return 3
+
+            z = z * abc
+            return z * z * z
+
+        self.checkScript(test_multiple, (5,))
+        self.checkScript(test_multiple, (2,))
+        self.checkScript(test_multiple, (4,))
+        self.checkScript(test_multiple, (3,))
+        self.checkScript(test_multiple, (10,))
+
+        graph = torch.jit.script(test_multiple).graph
+        FileCheck().check_count("prim::If", 3, exactly=True).run(graph)
+        print(torch.jit.script(test_multiple).code)

     def test_is_scripting_metacompile(self):
         @torch.jit.script
         def foo():
diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py
index 668157d5dba..ff573c499d9 100644
--- a/test/test_jit_fuser.py
+++ b/test/test_jit_fuser.py
@@ -473,6 +473,7 @@ class TestFuser(JitTestCase):

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
+    @torch.jit._disable_emit_hooks_decorator
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index d87e87013c6..b762103950e 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -500,6 +500,7 @@ class TestFuser(JitTestCase):

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
+    @torch.jit._disable_emit_hooks_decorator
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
diff --git a/torch/csrc/jit/frontend/exit_transforms.cpp b/torch/csrc/jit/frontend/exit_transforms.cpp
index 69e0e8cf2c1..94d2cf0fb6e 100644
--- a/torch/csrc/jit/frontend/exit_transforms.cpp
+++ b/torch/csrc/jit/frontend/exit_transforms.cpp
@@ -505,6 +505,101 @@ struct ExitTransformer {
   std::shared_ptr<Graph> graph_;
 };

+bool inlineConsecutiveIfs(Node* node) {
+  if (node->kind() != prim::If || node->next()->kind() != prim::If) {
+    return false;
+  }
+
+  IfView first_if(node);
+  IfView second_if(node->next());
+
+  // the second if must depend on a value outputted in the first if for us to
+  // inline the second if
+  if (second_if.cond()->node() != node) {
+    return false;
+  }
+
+  // both blocks must output a constant value for us to inline, and those
+  // values must be different. if the values are the same, then the subsequent
+  // if node will get constant prop'd away, and inlining it into the first
+  // node would double code size
+
+  auto input_offset = second_if.cond()->offset();
+  auto maybe_then_value = toIValue(first_if.thenOutputs().at(input_offset));
+  auto maybe_else_value = toIValue(first_if.elseOutputs().at(input_offset));
+  if (!maybe_then_value || !maybe_else_value ||
+      maybe_then_value->toBool() == maybe_else_value->toBool()) {
+    return false;
+  }
+
+  bool then_value = maybe_then_value->toBool();
+  bool else_value = maybe_else_value->toBool();
+
+  for (auto i = 0; i < 2; ++i) {
+    Block* first_if_block;
+    Block* second_if_block;
+
+    if (i == 0) {
+      first_if_block = first_if.thenBlock();
+      second_if_block =
+          then_value ? second_if.thenBlock() : second_if.elseBlock();
+    } else {
+      first_if_block = first_if.elseBlock();
+      second_if_block =
+          else_value ? second_if.thenBlock() : second_if.elseBlock();
+    }
+
+    // we need to replace values that were used in the second if that were
+    // outputs of the first if with the equivalent value in the scope of the
+    // block we're copying into
+    auto value_map = [&](Value* v) {
+      if (v->node() != first_if.node()) {
+        return v;
+      }
+      auto offset = v->offset();
+      return first_if_block->outputs().at(offset);
+    };
+
+    // cloneFrom also copies block outputs from second_if_block onto
+    // first_if_block
+    first_if_block->cloneFrom(second_if_block, value_map);
+  }
+
+  for (Value* output : second_if.outputs()) {
+    auto new_out = first_if.node()->addOutput()->copyMetadata(output);
+    output->replaceAllUsesWith(new_out);
+  }
+  second_if.node()->destroy();
+  return true;
+}
+
+// After an early return, we conditionalize all further execution.
+// This means code like the following:
+// if x:
+//     return 1
+// return 2
+// gets generated as one if statement checking `if x`, and then a second if
+// statement that conditionalizes execution. We can rewrite cases like these
+// into one if statement, so that the above example gets rewritten to look
+// like:
+// if x:
+//     return 1
+// else:
+//     return 2
+void inlineConsecutiveIfs(Block* block) {
+  for (auto it = block->nodes().begin(), end = block->nodes().end();
+       it != end;) {
+    for (Block* b : it->blocks()) {
+      inlineConsecutiveIfs(b);
+    }
+
+    // if we fused two ifs, we need to check the current node and new next node
+    if (!inlineConsecutiveIfs(*it)) {
+      it++;
+    }
+  }
+}
+
 // This pass takes in a graph where LoopContinuation & ReturnStmts exist in the
 // graph and erases them in the graph, correctly setting block outputs.
 // prim::LoopContinuation(*vals) means that the values are targeting the most
@@ -586,6 +681,7 @@ void TransformExits(std::shared_ptr<Graph>& graph) {
   e_loop.transformLoopContinuations();
   ExitTransformer e_ret(graph);
   e_ret.transformReturnStmts();
+  inlineConsecutiveIfs(graph->block());
 }
 } // namespace jit
 } // namespace torch
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index bab1ebcc3bc..06f9810cadf 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -13,6 +13,7 @@ from torch.nn import Module
 from torch.serialization import validate_cuda_device
 from torch._six import PY37, with_metaclass, string_classes, get_function_from_type
 from torch.utils import set_module
+from torch.autograd.grad_mode import _DecoratorContextManager
 import collections
 import contextlib
@@ -1098,7 +1099,6 @@ def _try_get_overloaded_fn(mod, field):
 class ScriptWarning(Warning):
     pass

-
 @contextlib.contextmanager
 def _disable_emit_hooks():
     hooks = torch._C._jit_get_emit_hooks()
@@ -1107,6 +1107,15 @@
     torch._C._jit_set_emit_hooks(hooks[0], hooks[1])


+class _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
+    def __enter__(self):
+        self.hooks = torch._C._jit_get_emit_hooks()
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
+
+
 # ScriptClasses must be new-style classes because we construct them using their
 # __new__ method.
 def _is_new_style_class(cls):
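
For reference, a minimal sketch of how to check the rewrite described in the summary, runnable against a build that includes this patch (the function name `early_return` and the assertion style are illustrative, not part of the change):

```
import torch

@torch.jit.script
def early_return(x: bool) -> int:
    # An early return followed by fall-through code: the exit transform
    # should now lower this to a single prim::If instead of two chained ones.
    if x:
        return 1
    return 2

print(early_return.code)
# Crude but sufficient: count prim::If nodes in the printed graph,
# mirroring the FileCheck assertion in test_early_return_rewrite.
assert str(early_return.graph).count("prim::If") == 1
```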