[TensorExpr] fix a bug in ReorderAxis when there are trailing loops (#38841)

Summary:
Fixes a bug in reorder axis where we append the new reordered loops to the enclosing block, even if there were statements after it. e.g. with 3 Computes:
```
for (int m1 ...
  for (int n1 ...
    for (int k1 ...
      Body 1
for (int m2 ...
  for (int n2 ...
    for (int k2 ...
      Body 2
for (int m3 ...
  for (int n3 ...
    for (int k3 ...
      Body 3
```

If we reorder loops m2 and k2, we were also reordering the body statements like this:

```
for (int m1 ...
  for (int n1 ...
    for (int k1 ...
      Body 1
for (int m3 ...
  for (int n3 ...
    for (int k3 ...
      Body 3
for (int k2 ...
  for (int n2 ...
    for (int m2 ...
      Body 2
```

This is because we always append the new loops to their parent. This PR fixes the logic to replace the old loop root with the new loop, which keeps things consistent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38841

Differential Revision: D21723670

Pulled By: nickgg

fbshipit-source-id: 1dee8bb153182fcaa2cabd948197577e8e80acd7
This commit is contained in:
Nick Gibson 2020-05-31 22:20:14 -07:00 committed by Facebook GitHub Bot
parent 68e62b9ab6
commit 5153cdbe87
4 changed files with 117 additions and 8 deletions

View file

@ -5,6 +5,7 @@
#include <unordered_map>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/buffer.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
@ -1630,6 +1631,112 @@ void testLoopNestReorderLongStringFull() {
}
}
void testLoopNestReorderInternalLoopNest() {
KernelScope kernel_scope;
const int M = 4;
const int N = 5;
const int K = 6;
Buffer a_buf("a", kFloat, {M, N});
Buffer b_buf("b", kFloat, {N, K});
Buffer c_buf("c", kFloat, {M, N});
Buffer d_buf("d", kFloat, {M, K});
Tensor* x = Compute(
"x",
{{M, "m1"}, {N, "n1"}, {K, "k1"}},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf(m, n) * b_buf(n, k);
});
Tensor* y = Compute(
"y",
{{M, "m2"}, {N, "n2"}, {K, "k2"}},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k);
});
Tensor* z = Compute(
"z",
{{M, "m3"}, {N, "n3"}, {K, "k3"}},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return x->call(m, n, k) + y->call(m, n, k);
});
LoopNest l({z});
For* a = nullptr;
For* b = nullptr;
auto fors = NodeFinder<For>::find(l.root_stmt());
for (auto* f : fors) {
if (f->var()->name_hint() == "m2") {
a = f;
} else if (f->var()->name_hint() == "k2") {
b = f;
}
}
l.reorderAxis(a, b);
l.prepareForCodegen();
Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *stmt;
// Check the IR we produced has the 3 nests in the right order, but k and m
// swapped in the middle.
const std::string& verification_pattern =
R"IR(
# CHECK: for (int m1
# CHECK: for (int n1
# CHECK: for (int k1
# CHECK: for (int k2
# CHECK: for (int n2
# CHECK: for (int m2
# CHECK: for (int m3
# CHECK: for (int n3
# CHECK: for (int k3)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
{
PaddedBuffer<float> a_v(M, N);
PaddedBuffer<float> b_v(N, K);
PaddedBuffer<float> c_v(M, N);
PaddedBuffer<float> d_v(M, K);
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
a_v(i, j) = i * i;
}
}
for (int i = 0; i < N; i++) {
for (int j = 0; j < K; j++) {
b_v(i, j) = j * j;
}
}
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
c_v(i, j) = i + j;
}
}
for (int i = 0; i < M; i++) {
for (int j = 0; j < K; j++) {
d_v(i, j) = i * j;
}
}
PaddedBuffer<float> z_v(M, N, K);
PaddedBuffer<float> z_ref(M, N, K);
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
for (int k = 0; k < K; k++) {
z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
}
}
}
SimpleIREvaluator eval(stmt, a_buf, b_buf, c_buf, d_buf, z);
eval(a_v, b_v, c_v, d_v, z_v);
ExpectAllNear(z_v, z_ref, 1e-5);
}
}
void testOuterLoopVectorization() {
KernelScope kernel_scope;
Tensor* tensor = Compute(

View file

@ -181,6 +181,7 @@ namespace jit {
_(LoopNestReorderLongStringOfPreOrphans) \
_(LoopNestReorderLongStringOfPostOrphans) \
_(LoopNestReorderLongStringFull) \
_(LoopNestReorderInternalLoopNest) \
_(OuterLoopVectorization) \
_(Kernel_1) \
_(Kernel_2) \

View file

@ -45,6 +45,7 @@ class NodeFinder : public IRVisitor {
std::vector<Node*> nodes;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch

View file

@ -1145,12 +1145,6 @@ void LoopNest::reorderAxis(For* a, For* b) {
}
}
// If the top level is now empty, eliminate it.
if (before->body()->nstmts() == 0) {
root->remove_stmt(before);
before = nullptr;
}
// now we can actually reorder the chosen axes.
std::swap(internal_axes.front(), internal_axes.back());
@ -1160,9 +1154,15 @@ void LoopNest::reorderAxis(For* a, For* b) {
}
// Append the new statements to the root of the tree.
root->append_stmt(newInner);
if (before->body()->nstmts() == 0) {
// If the top level is now empty, eliminate it.
root->replace_stmt(before, newInner);
} else {
root->insert_stmt_after(newInner, before);
}
if (after) {
root->append_stmt(after);
root->insert_stmt_after(after, newInner);
}
} // namespace tensorexpr