From d1c5f9e43993157f1c66e835c98624e4daef8564 Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Thu, 17 Feb 2022 10:18:33 -0800 Subject: [PATCH] [JIT][SR] Introduce prim::IfThenElse (#72587) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72587 This pattern frequently appears in a few graphs: ``` %result = prim::If(%condition) block0(): -> (%a) block1(): -> (%b) ``` This is slow, particularly in static runtime. Static runtime creates memory planners/block runners for each sub-block, which eats up a lot of memory and introduces a lot of extra overhead for this relatively simple operation. This diff introduces a new op that replaces nodes like the above with a single op meant to act like a ternary operator: ``` %result = prim::IfThenElse(%condition, %a, %b) ``` Test Plan: New unit tests Reviewed By: eellison Differential Revision: D34091789 fbshipit-source-id: eb6a8c460c39b4c019a1f4ab1f3f1e5b6edc400c (cherry picked from commit 0f1b335e5b83f402bda2dcdd9ecb411e0b67c651) --- aten/src/ATen/core/interned_strings.h | 1 + .../static_runtime/test_static_runtime.cc | 16 ++++++ test/cpp/jit/CMakeLists.txt | 1 + test/cpp/jit/test_add_if_then_else.cpp | 53 ++++++++++++++++++ tools/build_variables.bzl | 1 + torch/csrc/jit/passes/add_if_then_else.cpp | 55 +++++++++++++++++++ torch/csrc/jit/passes/add_if_then_else.h | 11 ++++ .../runtime/profiling_graph_executor_impl.cpp | 7 +++ .../runtime/profiling_graph_executor_impl.h | 1 + torch/csrc/jit/runtime/register_prim_ops.cpp | 11 ++++ torch/csrc/jit/runtime/static/impl.cpp | 5 +- torch/csrc/jit/runtime/static/impl.h | 3 +- torch/csrc/jit/runtime/static/native_ops.cpp | 12 ++++ 13 files changed, 175 insertions(+), 2 deletions(-) create mode 100644 test/cpp/jit/test_add_if_then_else.cpp create mode 100644 torch/csrc/jit/passes/add_if_then_else.cpp create mode 100644 torch/csrc/jit/passes/add_if_then_else.h diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index b2d6a43731f..88f275093d1 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -96,6 +96,7 @@ namespace c10 { _(prim, With) \ _(prim, Enter) \ _(prim, Exit) \ + _(prim, IfThenElse) \ _(aten, Bool) \ _(aten, Int) \ _(aten, FloatImplicit) \ diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index bc923e707e1..c3e9a050ff1 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -2720,3 +2720,19 @@ TEST(StaticRuntime, ToList) { )JIT"; testStaticRuntime(src, {at::randn({2, 2})}); } + +TEST(StaticRuntime, IfThenElse) { + const auto src = R"IR( + graph(%cond: bool, %a: Tensor, %b: Tensor): + %none: NoneType = prim::Constant() + %c: Tensor = prim::IfThenElse(%cond, %a, %b) + %d: Tensor = aten::clone(%c, %none) + return (%d) + )IR"; + + std::vector args1{true, at::randn({1}), at::randn({1})}; + std::vector args2{false, at::randn({1}), at::randn({1})}; + + testStaticRuntime(src, args1); + testStaticRuntime(src, args2); +} diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index cfdbb28a676..7358af08582 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -39,6 +39,7 @@ endif() # Build the cpp gtest binary containing the cpp-only tests. set(JIT_TEST_SRCS + ${JIT_TEST_ROOT}/test_add_if_then_else.cpp ${JIT_TEST_ROOT}/test_alias_analysis.cpp ${JIT_TEST_ROOT}/test_argument_spec.cpp ${JIT_TEST_ROOT}/test_autodiff.cpp diff --git a/test/cpp/jit/test_add_if_then_else.cpp b/test/cpp/jit/test_add_if_then_else.cpp new file mode 100644 index 00000000000..4850e1ab425 --- /dev/null +++ b/test/cpp/jit/test_add_if_then_else.cpp @@ -0,0 +1,53 @@ +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) { + const auto src = R"IR( + graph(%cond: bool, %a: Tensor, %b: Tensor): + %result: Tensor = prim::If(%cond) + block0(): + -> (%a) + block1(): + -> (%b) + return (%result) + )IR"; + + auto graph = std::make_shared(); + parseIR(src, graph.get()); + EXPECT_TRUE(AddIfThenElseOp(graph)); + + testing::FileCheck() + .check_count("= prim::IfThenElse", 1, /*exactly*/ true) + ->check_count("= prim::If", 0, /*exactly*/ true) + ->run(*graph); +} + +TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) { + const auto src = R"IR( + graph(%cond: bool, %a: Tensor, %b: Tensor): + %result1: Tensor, %result2: Tensor = prim::If(%cond) + block0(): + -> (%a, %b) + block1(): + -> (%b, %a) + return (%result1, %result2) + )IR"; + + auto graph = std::make_shared(); + parseIR(src, graph.get()); + EXPECT_FALSE(AddIfThenElseOp(graph)); + + testing::FileCheck() + .check_count("= prim::IfThenElse", 0, /*exactly*/ true) + ->check_count("= prim::If", 1, /*exactly*/ true) + ->run(*graph); +} + +} // namespace jit +} // namespace torch diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index c6a7e5a0791..67f2def297c 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -213,6 +213,7 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/operator_upgraders/utils.cpp", "torch/csrc/jit/operator_upgraders/upgraders.cpp", "torch/csrc/jit/operator_upgraders/upgraders_entry.cpp", + "torch/csrc/jit/passes/add_if_then_else.cpp", "torch/csrc/jit/passes/annotate_warns.cpp", "torch/csrc/jit/passes/bailout_graph.cpp", "torch/csrc/jit/passes/batch_mm.cpp", diff --git a/torch/csrc/jit/passes/add_if_then_else.cpp b/torch/csrc/jit/passes/add_if_then_else.cpp new file mode 100644 index 00000000000..72a085fd021 --- /dev/null +++ b/torch/csrc/jit/passes/add_if_then_else.cpp @@ -0,0 +1,55 @@ +#include +#include + +namespace torch { +namespace jit { + +namespace { + +bool hasNoNodes(Block* block) { + auto nodes = block->nodes(); + return nodes.begin() == nodes.end(); +} + +bool hasTrivialSubBlocks(Node* node) { + const auto blocks = node->blocks(); + DCHECK_EQ(blocks.size(), 2); + + return hasNoNodes(blocks[0]) && hasNoNodes(blocks[1]); +} + +} // namespace + +bool AddIfThenElseOp(std::shared_ptr& graph) { + std::vector to_replace; + DepthFirstGraphNodeIterator graph_it(graph); + for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) { + if (node->kind() != prim::If) { + continue; + } + if (node->outputs().size() != 1) { + continue; + } + if (hasTrivialSubBlocks(node)) { + to_replace.push_back(node); + } + } + + for (auto* node : to_replace) { + auto* if_then_else_node = graph->create(prim::IfThenElse, 1); + if_then_else_node->addInput(node->input()); + auto blocks = node->blocks(); + if_then_else_node->addInput(blocks[0]->return_node()->input()); + if_then_else_node->addInput(blocks[1]->return_node()->input()); + + if_then_else_node->insertBefore(node); + if_then_else_node->output()->copyMetadata(node->output()); + + node->output()->replaceAllUsesWith(if_then_else_node->output()); + node->destroy(); + } + return !to_replace.empty(); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/add_if_then_else.h b/torch/csrc/jit/passes/add_if_then_else.h new file mode 100644 index 00000000000..c6b3f9376d6 --- /dev/null +++ b/torch/csrc/jit/passes/add_if_then_else.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +TORCH_API bool AddIfThenElseOp(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index c0fc02e34d4..66a71a08596 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -650,6 +651,7 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor( // replaces a fallback graph inserted by // specialize_autogradzero if one exists replaceFallbackGraphWithFallbackFunction(copy->block()); + runFinalOptimizations(copy); GRAPH_DUMP("Optimized Graph: ", copy); optimized_plan_ = ExecutionPlan(copy, function_name_, *remaining_bailout_depth_); @@ -749,5 +751,10 @@ void ProfilingGraphExecutorImpl::replaceFallbackGraphWithFallbackFunction( } } +void ProfilingGraphExecutorImpl::runFinalOptimizations( + std::shared_ptr& graph) { + AddIfThenElseOp(graph); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h index 560eaca2cc3..117873934db 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h @@ -39,6 +39,7 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase { std::shared_ptr& graph, size_t remaining_depth); void replaceFallbackGraphWithFallbackFunction(Block* b); + void runFinalOptimizations(std::shared_ptr& graph); std::unique_ptr pr_; c10::optional profiling_plan_; // plan to run in order to profiling the code diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 0bf4f22aa7f..de8bc3ae86e 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -700,6 +700,17 @@ static const std::vector opGenArgs{ push(stack, at::stack(inputs, dim)); }, aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "prim::IfThenElse(bool cond, Any(a) x, Any(b) y) -> Any(a|b)"), + [](Stack& stack) { + const auto cond = stack[stack.size() - 3].toBool(); + stack[stack.size() - 3] = + std::move(stack[stack.size() - (cond ? 2 : 1)]); + stack.pop_back(); + stack.pop_back(); + }, + aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool"), diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index f8984129582..595e428e535 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -173,6 +174,7 @@ void OptimizeGraph( UseVariadicGroupedAccessor(graph); EliminateNoOps( graph, /* custom_ops */ {fromQualString("fb::scale_gradient")}); + AddIfThenElseOp(graph); GRAPH_DUMP("Final graph after optimizations: ", graph); } @@ -1846,8 +1848,9 @@ static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) { } bool ProcessedNode::verify_no_memory_overlap(bool force_check) const { - const static std::array special_case_ops = { + const static std::array special_case_ops = { fromQualString("prim::TypeCheck"), + fromQualString("prim::IfThenElse"), fromQualString("static_runtime::select_tensor"), fromQualString("static_runtime::VarTupleUnpack"), fromQualString("static_runtime::dict_unpack"), diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 27fcf2e5a24..6f3b0d9018a 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -58,10 +58,11 @@ TORCH_API inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) { } TORCH_API inline bool borrowsOutputs(c10::Symbol kind) { - static const std::array symbols_with_borrowed_outputs = { + static const std::array symbols_with_borrowed_outputs = { c10::Symbol::fromQualString("static_runtime::select_tensor"), c10::Symbol::fromQualString("static_runtime::dict_unpack"), c10::Symbol::fromQualString("static_runtime::VarTupleUnpack"), + c10::Symbol::fromQualString("prim::IfThenElse"), }; return std::find( symbols_with_borrowed_outputs.begin(), diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 33e2e27a7de..5d71e6b8135 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -946,5 +946,17 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( }; }); +// See [Borrowed IValue Outputs] +REGISTER_NATIVE_OPERATOR_FUNCTOR( + prim::IfThenElse, + prim_IfThenElse, + [](Node*) -> SROperator { + return [](ProcessedNode* pnode) { + const auto condition = pnode->Input(0).toBool(); + pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1)) + : createBorrowedIValue(pnode->Input(2)); + }; + }); + } // namespace jit } // namespace torch