[NNC] Fixes case where inlining wouldn't work because dim-size was 1. (#53254)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/52581

The git diff is absolutely atrocious since I also refactored the code to share stuff between `Load` and `FunctionCall`.

Biggest questions I have about this diff are:

1. The asserts I added. From my understanding it's not possible to have a constant index in `Store` that's non-zero, since `Store` always creates a new buffer. Perhaps the user can write this kind of incorrect code, though, so perhaps I should just check for it and not assert it?

2. I don't think(?) I need to do any special handling for `index_vars`, but wasn't totally able to track the logic there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53254

Reviewed By: albanD

Differential Revision: D26991064

Pulled By: Chillee

fbshipit-source-id: 0bcd612d5f4b031c0b34e68a72d9c8d12d118be8
This commit is contained in:
Horace He 2021-03-11 20:51:52 -08:00 committed by Facebook GitHub Bot
parent ce670238ba
commit d4602b7e45
2 changed files with 79 additions and 70 deletions

View file

@ -3755,6 +3755,28 @@ TEST(LoopNest, CompoundTensorSimple) {
assertAllEqual(a_data, a_ref);
}
TEST(LoopNest, InlineConstantIndex) {
KernelScope kernel_scope;
const int N = 10;
Placeholder x_buf("a", kFloat, {1, N, 1});
Tensor* y = Compute(
"f",
{{1, "m"}, {N, "n"}, {1, "o"}},
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
return x_buf.load(m, n, o);
});
Tensor* z = Compute(
"f",
{{1, "m"}, {N, "n"}, {1, "o"}},
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
return y->call(m, n, o);
});
LoopNest l({z});
l.simplify();
ASSERT_TRUE(l.computeInline(y->buf()));
}
TEST(LoopNest, CompoundTensorUsed) {
KernelScope kernel_scope;

View file

@ -532,16 +532,67 @@ class FunctionInliner : public IRMutator {
producer_(producer),
outputs_(std::move(outputs)) {
for (auto* i : producer->indices()) {
const Var* index_var = dynamic_cast<const Var*>(i);
if (index_var == nullptr) {
if (auto index_var = dynamic_cast<const Var*>(i)) {
index_vars_.insert(index_var);
producer_index_vars_.push_back(index_var);
} else if (dynamic_cast<const IntImm*>(i) != nullptr) {
// If the index can be a constant, then that dimension must have size 1
// (since we don't support in-place writes). Resolves issue 52581.
TORCH_INTERNAL_ASSERT(
dynamic_cast<const IntImm*>(i)->value() == 0,
"Constant index impression should always be zero");
producer_index_vars_.push_back(nullptr);
} else {
throw std::logic_error("cannot inline Buf with compound indices");
}
index_vars_.insert(index_var);
producer_index_vars_.push_back(index_var);
}
}
private:
const Expr* mutate_loads(const Buf* buf, std::vector<const Expr*> dims) {
std::vector<const Var*> index_vars;
TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
for (size_t i = 0; i < buf->ndim(); i++) {
const Var* func_callee_arg = producer_index_vars_.at(i);
const Expr* func_caller_param = dims.at(i);
if (func_callee_arg == nullptr) {
TORCH_INTERNAL_ASSERT(
dynamic_cast<const IntImm*>(func_caller_param) != nullptr &&
dynamic_cast<const IntImm*>(func_caller_param)->value() == 0,
"We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0");
continue;
}
if (func_callee_arg == nullptr)
continue;
auto iter = inline_mapping_.find(func_callee_arg);
if (iter != inline_mapping_.end()) {
throw std::runtime_error(
"Duplicated variables: " + func_callee_arg->name_hint());
}
// Add a mapping for each function parameter to it's source name.
inline_mapping_[func_callee_arg] = func_caller_param;
index_vars.push_back(func_callee_arg);
}
// Call the actual replacement.
const Expr* body = producer_->value();
const Expr* result = body->accept_mutator(this);
// Remove the mappings we created for this function parameters.
for (auto* v : index_vars) {
for (auto& pair : random_bindings_) {
if (pair.second.erase(v)) {
const Expr* inlined = inline_mapping_[v];
for (auto* nv : VarFinder::find(inlined)) {
pair.second.insert(nv);
}
}
}
inline_mapping_.erase(v);
}
return result;
}
// For the target function, insert the caller/callee pair into the replacement
// mapping.
const Expr* mutate(const FunctionCall* v) override {
@ -555,39 +606,7 @@ class FunctionInliner : public IRMutator {
throw malformed_input(
"Placeholder indexed access is inconsistent with its rank", v);
}
std::vector<const Var*> index_vars;
TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
for (size_t i = 0; i < buf->ndim(); i++) {
const Var* func_callee_arg = producer_index_vars_.at(i);
const Expr* func_caller_param = v->param(i);
auto iter = inline_mapping_.find(func_callee_arg);
if (iter != inline_mapping_.end()) {
throw std::runtime_error(
"Duplicated variables: " + func_callee_arg->name_hint());
}
// Add a mapping for each function parameter to it's source name.
inline_mapping_[func_callee_arg] = func_caller_param;
index_vars.push_back(func_callee_arg);
}
// Call the actual replacement.
const Expr* body = producer_->value();
const Expr* result = body->accept_mutator(this);
// Remove the mappings we created for this function parameters.
for (auto* v : index_vars) {
for (auto& pair : random_bindings_) {
if (pair.second.erase(v)) {
const Expr* inlined = inline_mapping_[v];
for (auto* nv : VarFinder::find(inlined)) {
pair.second.insert(nv);
}
}
}
inline_mapping_.erase(v);
}
return result;
return mutate_loads(buf, v->params());
}
const Expr* mutate(const Load* v) override {
@ -600,39 +619,7 @@ class FunctionInliner : public IRMutator {
throw malformed_input(
"Placeholder indexed access is inconsistent with its rank", v);
}
std::vector<const Var*> index_vars;
TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
for (size_t i = 0; i < buf->ndim(); i++) {
const Var* func_callee_arg = producer_index_vars_.at(i);
const Expr* func_caller_param = v->indices()[i];
auto iter = inline_mapping_.find(func_callee_arg);
if (iter != inline_mapping_.end()) {
throw std::runtime_error(
"Duplicated variables: " + func_callee_arg->name_hint());
}
// Add a mapping for each function parameter to it's source name.
inline_mapping_[func_callee_arg] = func_caller_param;
index_vars.push_back(func_callee_arg);
}
// Call the actual replacement.
const Expr* body = producer_->value();
const Expr* result = body->accept_mutator(this);
// Remove the mappings we created for this function parameters.
for (auto* v : index_vars) {
for (auto& pair : random_bindings_) {
if (pair.second.erase(v)) {
const Expr* inlined = inline_mapping_[v];
for (auto* nv : VarFinder::find(inlined)) {
pair.second.insert(nv);
}
}
}
inline_mapping_.erase(v);
}
return result;
return mutate_loads(buf, v->indices());
}
// Replace the target variable with the caller expressions.