mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[TensorExpr] fix a bug in ReorderAxis when there are trailing loops (#38841)
Summary:
Fixes a bug in reorderAxis where the newly reordered loops were appended to the end of the enclosing block even when other statements followed the original loop nest. e.g. with 3 Computes:
```
for (int m1 ...
for (int n1 ...
for (int k1 ...
Body 1
for (int m2 ...
for (int n2 ...
for (int k2 ...
Body 2
for (int m3 ...
for (int n3 ...
for (int k3 ...
Body 3
```
If we reorder loops m2 and k2, the entire second loop nest was incorrectly moved to the end of the enclosing block, like this:
```
for (int m1 ...
for (int n1 ...
for (int k1 ...
Body 1
for (int m3 ...
for (int n3 ...
for (int k3 ...
Body 3
for (int k2 ...
for (int n2 ...
for (int m2 ...
Body 2
```
This happened because the new loops were always appended to the end of their parent block. This PR fixes the logic to insert the new loop in place of the old loop root, which keeps the statement order consistent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38841
Differential Revision: D21723670
Pulled By: nickgg
fbshipit-source-id: 1dee8bb153182fcaa2cabd948197577e8e80acd7
This commit is contained in:
parent
68e62b9ab6
commit
5153cdbe87
4 changed files with 117 additions and 8 deletions
|
|
@ -5,6 +5,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include <test/cpp/tensorexpr/padded_buffer.h>
|
||||
#include <torch/csrc/jit/tensorexpr/analysis.h>
|
||||
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
|
||||
#include <torch/csrc/jit/tensorexpr/buffer.h>
|
||||
#include <torch/csrc/jit/tensorexpr/eval.h>
|
||||
|
|
@ -1630,6 +1631,112 @@ void testLoopNestReorderLongStringFull() {
|
|||
}
|
||||
}
|
||||
|
||||
void testLoopNestReorderInternalLoopNest() {
|
||||
KernelScope kernel_scope;
|
||||
const int M = 4;
|
||||
const int N = 5;
|
||||
const int K = 6;
|
||||
Buffer a_buf("a", kFloat, {M, N});
|
||||
Buffer b_buf("b", kFloat, {N, K});
|
||||
Buffer c_buf("c", kFloat, {M, N});
|
||||
Buffer d_buf("d", kFloat, {M, K});
|
||||
|
||||
Tensor* x = Compute(
|
||||
"x",
|
||||
{{M, "m1"}, {N, "n1"}, {K, "k1"}},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return a_buf(m, n) * b_buf(n, k);
|
||||
});
|
||||
Tensor* y = Compute(
|
||||
"y",
|
||||
{{M, "m2"}, {N, "n2"}, {K, "k2"}},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k);
|
||||
});
|
||||
Tensor* z = Compute(
|
||||
"z",
|
||||
{{M, "m3"}, {N, "n3"}, {K, "k3"}},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return x->call(m, n, k) + y->call(m, n, k);
|
||||
});
|
||||
|
||||
LoopNest l({z});
|
||||
For* a = nullptr;
|
||||
For* b = nullptr;
|
||||
auto fors = NodeFinder<For>::find(l.root_stmt());
|
||||
for (auto* f : fors) {
|
||||
if (f->var()->name_hint() == "m2") {
|
||||
a = f;
|
||||
} else if (f->var()->name_hint() == "k2") {
|
||||
b = f;
|
||||
}
|
||||
}
|
||||
l.reorderAxis(a, b);
|
||||
|
||||
l.prepareForCodegen();
|
||||
Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *stmt;
|
||||
|
||||
// Check the IR we produced has the 3 nests in the right order, but k and m
|
||||
// swapped in the middle.
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK: for (int m1
|
||||
# CHECK: for (int n1
|
||||
# CHECK: for (int k1
|
||||
# CHECK: for (int k2
|
||||
# CHECK: for (int n2
|
||||
# CHECK: for (int m2
|
||||
# CHECK: for (int m3
|
||||
# CHECK: for (int n3
|
||||
# CHECK: for (int k3)IR";
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
||||
{
|
||||
PaddedBuffer<float> a_v(M, N);
|
||||
PaddedBuffer<float> b_v(N, K);
|
||||
PaddedBuffer<float> c_v(M, N);
|
||||
PaddedBuffer<float> d_v(M, K);
|
||||
|
||||
for (int i = 0; i < M; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
a_v(i, j) = i * i;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < N; i++) {
|
||||
for (int j = 0; j < K; j++) {
|
||||
b_v(i, j) = j * j;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < M; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
c_v(i, j) = i + j;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < M; i++) {
|
||||
for (int j = 0; j < K; j++) {
|
||||
d_v(i, j) = i * j;
|
||||
}
|
||||
}
|
||||
|
||||
PaddedBuffer<float> z_v(M, N, K);
|
||||
PaddedBuffer<float> z_ref(M, N, K);
|
||||
for (int m = 0; m < M; m++) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int k = 0; k < K; k++) {
|
||||
z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SimpleIREvaluator eval(stmt, a_buf, b_buf, c_buf, d_buf, z);
|
||||
eval(a_v, b_v, c_v, d_v, z_v);
|
||||
ExpectAllNear(z_v, z_ref, 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
void testOuterLoopVectorization() {
|
||||
KernelScope kernel_scope;
|
||||
Tensor* tensor = Compute(
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ namespace jit {
|
|||
_(LoopNestReorderLongStringOfPreOrphans) \
|
||||
_(LoopNestReorderLongStringOfPostOrphans) \
|
||||
_(LoopNestReorderLongStringFull) \
|
||||
_(LoopNestReorderInternalLoopNest) \
|
||||
_(OuterLoopVectorization) \
|
||||
_(Kernel_1) \
|
||||
_(Kernel_2) \
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class NodeFinder : public IRVisitor {
|
|||
|
||||
std::vector<Node*> nodes;
|
||||
};
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -1145,12 +1145,6 @@ void LoopNest::reorderAxis(For* a, For* b) {
|
|||
}
|
||||
}
|
||||
|
||||
// If the top level is now empty, eliminate it.
|
||||
if (before->body()->nstmts() == 0) {
|
||||
root->remove_stmt(before);
|
||||
before = nullptr;
|
||||
}
|
||||
|
||||
// now we can actually reorder the chosen axes.
|
||||
std::swap(internal_axes.front(), internal_axes.back());
|
||||
|
||||
|
|
@ -1160,9 +1154,15 @@ void LoopNest::reorderAxis(For* a, For* b) {
|
|||
}
|
||||
|
||||
// Append the new statements to the root of the tree.
|
||||
root->append_stmt(newInner);
|
||||
if (before->body()->nstmts() == 0) {
|
||||
// If the top level is now empty, eliminate it.
|
||||
root->replace_stmt(before, newInner);
|
||||
} else {
|
||||
root->insert_stmt_after(newInner, before);
|
||||
}
|
||||
|
||||
if (after) {
|
||||
root->append_stmt(after);
|
||||
root->insert_stmt_after(after, newInner);
|
||||
}
|
||||
} // namespace tensorexpr
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue