From d4602b7e459c4ad19cf1f59cd157cc4937f8ba72 Mon Sep 17 00:00:00 2001
From: Horace He
Date: Thu, 11 Mar 2021 20:51:52 -0800
Subject: [PATCH] [NNC] Fixes case where inlining wouldn't work because
 dim-size was 1. (#53254)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/52581

The git diff is absolutely atrocious since I also refactored the code to share stuff between `Load` and `FunctionCall`.

The biggest questions I have about this diff are:
1. The asserts I added. From my understanding, it's not possible to have a constant index in `Store` that's non-zero, since `Store` always creates a new buffer. Perhaps the user can write this kind of incorrect code, though, so maybe I should just check for it instead of asserting?
2. I don't think(?) I need to do any special handling for `index_vars`, but I wasn't totally able to track the logic there.
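For context, here is a minimal sketch of the pattern that previously failed to inline, adapted from the `InlineConstantIndex` test added below (names and shapes are illustrative). After `simplify()`, the loop variables for the size-1 dimensions collapse to the constant `0`, so the producer's store indices are no longer all plain `Var`s, which the old `FunctionInliner` rejected:

```cpp
KernelScope kernel_scope;
// A buffer whose first and last dimensions have size 1.
Placeholder x("x", kFloat, {1, 10, 1});
Tensor* y = Compute(
    "y",
    {{1, "m"}, {10, "n"}, {1, "o"}},
    [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
      return x.load(m, n, o);
    });
Tensor* z = Compute(
    "z",
    {{1, "m"}, {10, "n"}, {1, "o"}},
    [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
      return y->call(m, n, o);
    });

LoopNest l({z});
l.simplify();               // rewrites the size-1 loop vars m and o to the constant 0
l.computeInline(y->buf());  // used to fail: index 0 is an IntImm, not a Var
```

With this patch, a constant producer index is accepted as long as it is zero, and the matching consumer index is required to be the constant zero as well.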
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53254

Reviewed By: albanD

Differential Revision: D26991064

Pulled By: Chillee

fbshipit-source-id: 0bcd612d5f4b031c0b34e68a72d9c8d12d118be8
---
 test/cpp/tensorexpr/test_loopnest.cpp  |  22 +++++
 torch/csrc/jit/tensorexpr/loopnest.cpp | 127 +++++++++++--------------
 2 files changed, 79 insertions(+), 70 deletions(-)

diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp
index a34bde264f8..5924d442140 100644
--- a/test/cpp/tensorexpr/test_loopnest.cpp
+++ b/test/cpp/tensorexpr/test_loopnest.cpp
@@ -3755,6 +3755,28 @@ TEST(LoopNest, CompoundTensorSimple) {
   assertAllEqual(a_data, a_ref);
 }
 
+TEST(LoopNest, InlineConstantIndex) {
+  KernelScope kernel_scope;
+  const int N = 10;
+  Placeholder x_buf("a", kFloat, {1, N, 1});
+  Tensor* y = Compute(
+      "f",
+      {{1, "m"}, {N, "n"}, {1, "o"}},
+      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
+        return x_buf.load(m, n, o);
+      });
+  Tensor* z = Compute(
+      "f",
+      {{1, "m"}, {N, "n"}, {1, "o"}},
+      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
+        return y->call(m, n, o);
+      });
+
+  LoopNest l({z});
+  l.simplify();
+  ASSERT_TRUE(l.computeInline(y->buf()));
+}
+
 TEST(LoopNest, CompoundTensorUsed) {
   KernelScope kernel_scope;
 
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 3dea80130b7..09308698b64 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -532,16 +532,67 @@ class FunctionInliner : public IRMutator {
         producer_(producer),
         outputs_(std::move(outputs)) {
     for (auto* i : producer->indices()) {
-      const Var* index_var = dynamic_cast<const Var*>(i);
-      if (index_var == nullptr) {
+      if (auto index_var = dynamic_cast<const Var*>(i)) {
+        index_vars_.insert(index_var);
+        producer_index_vars_.push_back(index_var);
+      } else if (dynamic_cast<const IntImm*>(i) != nullptr) {
+        // If the index is a constant, then that dimension must have size 1
+        // (since we don't support in-place writes). Resolves issue 52581.
+        TORCH_INTERNAL_ASSERT(
+            dynamic_cast<const IntImm*>(i)->value() == 0,
+            "Constant index expression should always be zero");
+        producer_index_vars_.push_back(nullptr);
+      } else {
         throw std::logic_error("cannot inline Buf with compound indices");
       }
-      index_vars_.insert(index_var);
-      producer_index_vars_.push_back(index_var);
     }
   }
 
  private:
+  // Shared between the Load and FunctionCall mutators: maps each producer
+  // index variable to the corresponding caller index expression, mutates the
+  // producer body, then removes the temporary mappings.
+  const Expr* mutate_loads(const Buf* buf, std::vector<const Expr*> dims) {
+    std::vector<const Var*> index_vars;
+    TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
+    for (size_t i = 0; i < buf->ndim(); i++) {
+      const Var* func_callee_arg = producer_index_vars_.at(i);
+      const Expr* func_caller_param = dims.at(i);
+      if (func_callee_arg == nullptr) {
+        // This dimension was stored at the constant index 0, so the caller
+        // must load it at the constant index 0 as well.
+        TORCH_INTERNAL_ASSERT(
+            dynamic_cast<const IntImm*>(func_caller_param) != nullptr &&
+                dynamic_cast<const IntImm*>(func_caller_param)->value() == 0,
+            "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0");
+        continue;
+      }
+      auto iter = inline_mapping_.find(func_callee_arg);
+      if (iter != inline_mapping_.end()) {
+        throw std::runtime_error(
+            "Duplicated variables: " + func_callee_arg->name_hint());
+      }
+      // Add a mapping for each function parameter to its source name.
+      inline_mapping_[func_callee_arg] = func_caller_param;
+      index_vars.push_back(func_callee_arg);
+    }
+
+    // Call the actual replacement.
+    const Expr* body = producer_->value();
+    const Expr* result = body->accept_mutator(this);
+
+    // Remove the mappings we created for this function's parameters.
+    for (auto* v : index_vars) {
+      for (auto& pair : random_bindings_) {
+        if (pair.second.erase(v)) {
+          const Expr* inlined = inline_mapping_[v];
+          for (auto* nv : VarFinder::find(inlined)) {
+            pair.second.insert(nv);
+          }
+        }
+      }
+      inline_mapping_.erase(v);
+    }
+    return result;
+  }
+
   // For the target function, insert the caller/callee pair into the replacement
   // mapping.
   const Expr* mutate(const FunctionCall* v) override {
@@ -555,39 +606,7 @@ class FunctionInliner : public IRMutator {
       throw malformed_input(
           "Placeholder indexed access is inconsistent with its rank", v);
     }
-
-    std::vector<const Var*> index_vars;
-    TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
-    for (size_t i = 0; i < buf->ndim(); i++) {
-      const Var* func_callee_arg = producer_index_vars_.at(i);
-      const Expr* func_caller_param = v->param(i);
-      auto iter = inline_mapping_.find(func_callee_arg);
-      if (iter != inline_mapping_.end()) {
-        throw std::runtime_error(
-            "Duplicated variables: " + func_callee_arg->name_hint());
-      }
-      // Add a mapping for each function parameter to it's source name.
-      inline_mapping_[func_callee_arg] = func_caller_param;
-      index_vars.push_back(func_callee_arg);
-    }
-
-    // Call the actual replacement.
-    const Expr* body = producer_->value();
-    const Expr* result = body->accept_mutator(this);
-
-    // Remove the mappings we created for this function parameters.
-    for (auto* v : index_vars) {
-      for (auto& pair : random_bindings_) {
-        if (pair.second.erase(v)) {
-          const Expr* inlined = inline_mapping_[v];
-          for (auto* nv : VarFinder::find(inlined)) {
-            pair.second.insert(nv);
-          }
-        }
-      }
-      inline_mapping_.erase(v);
-    }
-    return result;
+    return mutate_loads(buf, v->params());
   }
 
   const Expr* mutate(const Load* v) override {
@@ -600,39 +619,7 @@ class FunctionInliner : public IRMutator {
       throw malformed_input(
          "Placeholder indexed access is inconsistent with its rank", v);
     }
-
-    std::vector<const Var*> index_vars;
-    TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
-    for (size_t i = 0; i < buf->ndim(); i++) {
-      const Var* func_callee_arg = producer_index_vars_.at(i);
-      const Expr* func_caller_param = v->indices()[i];
-      auto iter = inline_mapping_.find(func_callee_arg);
-      if (iter != inline_mapping_.end()) {
-        throw std::runtime_error(
-            "Duplicated variables: " + func_callee_arg->name_hint());
-      }
-      // Add a mapping for each function parameter to it's source name.
-      inline_mapping_[func_callee_arg] = func_caller_param;
-      index_vars.push_back(func_callee_arg);
-    }
-
-    // Call the actual replacement.
-    const Expr* body = producer_->value();
-    const Expr* result = body->accept_mutator(this);
-
-    // Remove the mappings we created for this function parameters.
-    for (auto* v : index_vars) {
-      for (auto& pair : random_bindings_) {
-        if (pair.second.erase(v)) {
-          const Expr* inlined = inline_mapping_[v];
-          for (auto* nv : VarFinder::find(inlined)) {
-            pair.second.insert(nv);
-          }
-        }
-      }
-      inline_mapping_.erase(v);
-    }
-    return result;
+    return mutate_loads(buf, v->indices());
   }
 
   // Replace the target variable with the caller expressions.