From 5153cdbe8754aadd6bb7119feaca6ffb91bb6abc Mon Sep 17 00:00:00 2001 From: Nick Gibson Date: Sun, 31 May 2020 22:20:14 -0700 Subject: [PATCH] [TensorExpr] fix a bug in ReorderAxis when there are trailing loops (#38841) Summary: Fixes a bug in reorder axis where we append the new reordered loops to the enclosing block, even if there were statements after it. e.g. with 3 Computes: ``` for (int m1 ... for (int n1 ... for (int k1 ... Body 1 for (int m2 ... for (int n2 ... for (int k2 ... Body 2 for (int m3 ... for (int n3 ... for (int k3 ... Body 3 ``` If we reorder loops m2 and k2, we were also reordering the body statements like this: ``` for (int m1 ... for (int n1 ... for (int k1 ... Body 1 for (int m3 ... for (int n3 ... for (int k3 ... Body 3 for (int k2 ... for (int n2 ... for (int m2 ... Body 2 ``` This is because we always append the new loops to their parent. This PR fixes the logic to replace the old loop root with the new loop, which keeps things consistent. Pull Request resolved: https://github.com/pytorch/pytorch/pull/38841 Differential Revision: D21723670 Pulled By: nickgg fbshipit-source-id: 1dee8bb153182fcaa2cabd948197577e8e80acd7 --- test/cpp/tensorexpr/test_loopnest.cpp | 107 +++++++++++++++++++++++++ test/cpp/tensorexpr/tests.h | 1 + torch/csrc/jit/tensorexpr/analysis.h | 1 + torch/csrc/jit/tensorexpr/loopnest.cpp | 16 ++-- 4 files changed, 117 insertions(+), 8 deletions(-) diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index d18c374e4d6..e7a08a3a11e 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -1630,6 +1631,112 @@ void testLoopNestReorderLongStringFull() { } } +void testLoopNestReorderInternalLoopNest() { + KernelScope kernel_scope; + const int M = 4; + const int N = 5; + const int K = 6; + Buffer a_buf("a", kFloat, {M, N}); + Buffer b_buf("b", kFloat, {N, K}); + Buffer c_buf("c", kFloat, {M, N}); + Buffer d_buf("d", kFloat, {M, K}); + + Tensor* x = Compute( + "x", + {{M, "m1"}, {N, "n1"}, {K, "k1"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf(m, n) * b_buf(n, k); + }); + Tensor* y = Compute( + "y", + {{M, "m2"}, {N, "n2"}, {K, "k2"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k); + }); + Tensor* z = Compute( + "z", + {{M, "m3"}, {N, "n3"}, {K, "k3"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x->call(m, n, k) + y->call(m, n, k); + }); + + LoopNest l({z}); + For* a = nullptr; + For* b = nullptr; + auto fors = NodeFinder::find(l.root_stmt()); + for (auto* f : fors) { + if (f->var()->name_hint() == "m2") { + a = f; + } else if (f->var()->name_hint() == "k2") { + b = f; + } + } + l.reorderAxis(a, b); + + l.prepareForCodegen(); + Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *stmt; + + // Check the IR we produced has the 3 nests in the right order, but k and m + // swapped in the middle. + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m1 +# CHECK: for (int n1 +# CHECK: for (int k1 +# CHECK: for (int k2 +# CHECK: for (int n2 +# CHECK: for (int m2 +# CHECK: for (int m3 +# CHECK: for (int n3 +# CHECK: for (int k3)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + { + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer z_v(M, N, K); + PaddedBuffer z_ref(M, N, K); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); + } + } + } + + SimpleIREvaluator eval(stmt, a_buf, b_buf, c_buf, d_buf, z); + eval(a_v, b_v, c_v, d_v, z_v); + ExpectAllNear(z_v, z_ref, 1e-5); + } +} + void testOuterLoopVectorization() { KernelScope kernel_scope; Tensor* tensor = Compute( diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 8eb494529fe..e9ba01f8ad4 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -181,6 +181,7 @@ namespace jit { _(LoopNestReorderLongStringOfPreOrphans) \ _(LoopNestReorderLongStringOfPostOrphans) \ _(LoopNestReorderLongStringFull) \ + _(LoopNestReorderInternalLoopNest) \ _(OuterLoopVectorization) \ _(Kernel_1) \ _(Kernel_2) \ diff --git a/torch/csrc/jit/tensorexpr/analysis.h b/torch/csrc/jit/tensorexpr/analysis.h index febfe51053a..d66cc9896fd 100644 --- a/torch/csrc/jit/tensorexpr/analysis.h +++ b/torch/csrc/jit/tensorexpr/analysis.h @@ -45,6 +45,7 @@ class NodeFinder : public IRVisitor { std::vector nodes; }; + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index d5c8b0dcbc2..3c0904c1c9f 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1145,12 +1145,6 @@ void LoopNest::reorderAxis(For* a, For* b) { } } - // If the top level is now empty, eliminate it. - if (before->body()->nstmts() == 0) { - root->remove_stmt(before); - before = nullptr; - } - // now we can actually reorder the chosen axes. std::swap(internal_axes.front(), internal_axes.back()); @@ -1160,9 +1154,15 @@ void LoopNest::reorderAxis(For* a, For* b) { } // Append the new statements to the root of the tree. - root->append_stmt(newInner); + if (before->body()->nstmts() == 0) { + // If the top level is now empty, eliminate it. + root->replace_stmt(before, newInner); + } else { + root->insert_stmt_after(newInner, before); + } + if (after) { - root->append_stmt(after); + root->insert_stmt_after(after, newInner); } } // namespace tensorexpr