mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[NNC] Fixes case where inlining wouldn't work because dim-size was 1. (#53254)
Summary: Fixes https://github.com/pytorch/pytorch/issues/52581 The git diff is absolutely atrocious since I also refactored the code to share stuff between `Load` and `FunctionCall`. Biggest questions I have about this diff are: 1. The asserts I added. From my understanding it's not possible to have a constant index in `Store` that's non-zero, since `Store` always creates a new buffer. Perhaps the user can write this kind of incorrect code, though, so perhaps I should just check for it and not assert it? 2. I don't think(?) I need to do any special handling for `index_vars`, but wasn't totally able to track the logic there. Pull Request resolved: https://github.com/pytorch/pytorch/pull/53254 Reviewed By: albanD Differential Revision: D26991064 Pulled By: Chillee fbshipit-source-id: 0bcd612d5f4b031c0b34e68a72d9c8d12d118be8
This commit is contained in:
parent
ce670238ba
commit
d4602b7e45
2 changed files with 79 additions and 70 deletions
|
|
@ -3755,6 +3755,28 @@ TEST(LoopNest, CompoundTensorSimple) {
|
|||
assertAllEqual(a_data, a_ref);
|
||||
}
|
||||
|
||||
TEST(LoopNest, InlineConstantIndex) {
|
||||
KernelScope kernel_scope;
|
||||
const int N = 10;
|
||||
Placeholder x_buf("a", kFloat, {1, N, 1});
|
||||
Tensor* y = Compute(
|
||||
"f",
|
||||
{{1, "m"}, {N, "n"}, {1, "o"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
|
||||
return x_buf.load(m, n, o);
|
||||
});
|
||||
Tensor* z = Compute(
|
||||
"f",
|
||||
{{1, "m"}, {N, "n"}, {1, "o"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) {
|
||||
return y->call(m, n, o);
|
||||
});
|
||||
|
||||
LoopNest l({z});
|
||||
l.simplify();
|
||||
ASSERT_TRUE(l.computeInline(y->buf()));
|
||||
}
|
||||
|
||||
TEST(LoopNest, CompoundTensorUsed) {
|
||||
KernelScope kernel_scope;
|
||||
|
||||
|
|
|
|||
|
|
@ -532,16 +532,67 @@ class FunctionInliner : public IRMutator {
|
|||
producer_(producer),
|
||||
outputs_(std::move(outputs)) {
|
||||
for (auto* i : producer->indices()) {
|
||||
const Var* index_var = dynamic_cast<const Var*>(i);
|
||||
if (index_var == nullptr) {
|
||||
if (auto index_var = dynamic_cast<const Var*>(i)) {
|
||||
index_vars_.insert(index_var);
|
||||
producer_index_vars_.push_back(index_var);
|
||||
} else if (dynamic_cast<const IntImm*>(i) != nullptr) {
|
||||
// If the index can be a constant, then that dimension must have size 1
|
||||
// (since we don't support in-place writes). Resolves issue 52581.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dynamic_cast<const IntImm*>(i)->value() == 0,
|
||||
"Constant index impression should always be zero");
|
||||
producer_index_vars_.push_back(nullptr);
|
||||
} else {
|
||||
throw std::logic_error("cannot inline Buf with compound indices");
|
||||
}
|
||||
index_vars_.insert(index_var);
|
||||
producer_index_vars_.push_back(index_var);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const Expr* mutate_loads(const Buf* buf, std::vector<const Expr*> dims) {
|
||||
std::vector<const Var*> index_vars;
|
||||
TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
|
||||
for (size_t i = 0; i < buf->ndim(); i++) {
|
||||
const Var* func_callee_arg = producer_index_vars_.at(i);
|
||||
const Expr* func_caller_param = dims.at(i);
|
||||
if (func_callee_arg == nullptr) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dynamic_cast<const IntImm*>(func_caller_param) != nullptr &&
|
||||
dynamic_cast<const IntImm*>(func_caller_param)->value() == 0,
|
||||
"We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0");
|
||||
continue;
|
||||
}
|
||||
if (func_callee_arg == nullptr)
|
||||
continue;
|
||||
auto iter = inline_mapping_.find(func_callee_arg);
|
||||
if (iter != inline_mapping_.end()) {
|
||||
throw std::runtime_error(
|
||||
"Duplicated variables: " + func_callee_arg->name_hint());
|
||||
}
|
||||
// Add a mapping for each function parameter to it's source name.
|
||||
inline_mapping_[func_callee_arg] = func_caller_param;
|
||||
index_vars.push_back(func_callee_arg);
|
||||
}
|
||||
|
||||
// Call the actual replacement.
|
||||
const Expr* body = producer_->value();
|
||||
const Expr* result = body->accept_mutator(this);
|
||||
|
||||
// Remove the mappings we created for this function parameters.
|
||||
for (auto* v : index_vars) {
|
||||
for (auto& pair : random_bindings_) {
|
||||
if (pair.second.erase(v)) {
|
||||
const Expr* inlined = inline_mapping_[v];
|
||||
for (auto* nv : VarFinder::find(inlined)) {
|
||||
pair.second.insert(nv);
|
||||
}
|
||||
}
|
||||
}
|
||||
inline_mapping_.erase(v);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// For the target function, insert the caller/callee pair into the replacement
|
||||
// mapping.
|
||||
const Expr* mutate(const FunctionCall* v) override {
|
||||
|
|
@ -555,39 +606,7 @@ class FunctionInliner : public IRMutator {
|
|||
throw malformed_input(
|
||||
"Placeholder indexed access is inconsistent with its rank", v);
|
||||
}
|
||||
|
||||
std::vector<const Var*> index_vars;
|
||||
TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
|
||||
for (size_t i = 0; i < buf->ndim(); i++) {
|
||||
const Var* func_callee_arg = producer_index_vars_.at(i);
|
||||
const Expr* func_caller_param = v->param(i);
|
||||
auto iter = inline_mapping_.find(func_callee_arg);
|
||||
if (iter != inline_mapping_.end()) {
|
||||
throw std::runtime_error(
|
||||
"Duplicated variables: " + func_callee_arg->name_hint());
|
||||
}
|
||||
// Add a mapping for each function parameter to it's source name.
|
||||
inline_mapping_[func_callee_arg] = func_caller_param;
|
||||
index_vars.push_back(func_callee_arg);
|
||||
}
|
||||
|
||||
// Call the actual replacement.
|
||||
const Expr* body = producer_->value();
|
||||
const Expr* result = body->accept_mutator(this);
|
||||
|
||||
// Remove the mappings we created for this function parameters.
|
||||
for (auto* v : index_vars) {
|
||||
for (auto& pair : random_bindings_) {
|
||||
if (pair.second.erase(v)) {
|
||||
const Expr* inlined = inline_mapping_[v];
|
||||
for (auto* nv : VarFinder::find(inlined)) {
|
||||
pair.second.insert(nv);
|
||||
}
|
||||
}
|
||||
}
|
||||
inline_mapping_.erase(v);
|
||||
}
|
||||
return result;
|
||||
return mutate_loads(buf, v->params());
|
||||
}
|
||||
|
||||
const Expr* mutate(const Load* v) override {
|
||||
|
|
@ -600,39 +619,7 @@ class FunctionInliner : public IRMutator {
|
|||
throw malformed_input(
|
||||
"Placeholder indexed access is inconsistent with its rank", v);
|
||||
}
|
||||
|
||||
std::vector<const Var*> index_vars;
|
||||
TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
|
||||
for (size_t i = 0; i < buf->ndim(); i++) {
|
||||
const Var* func_callee_arg = producer_index_vars_.at(i);
|
||||
const Expr* func_caller_param = v->indices()[i];
|
||||
auto iter = inline_mapping_.find(func_callee_arg);
|
||||
if (iter != inline_mapping_.end()) {
|
||||
throw std::runtime_error(
|
||||
"Duplicated variables: " + func_callee_arg->name_hint());
|
||||
}
|
||||
// Add a mapping for each function parameter to it's source name.
|
||||
inline_mapping_[func_callee_arg] = func_caller_param;
|
||||
index_vars.push_back(func_callee_arg);
|
||||
}
|
||||
|
||||
// Call the actual replacement.
|
||||
const Expr* body = producer_->value();
|
||||
const Expr* result = body->accept_mutator(this);
|
||||
|
||||
// Remove the mappings we created for this function parameters.
|
||||
for (auto* v : index_vars) {
|
||||
for (auto& pair : random_bindings_) {
|
||||
if (pair.second.erase(v)) {
|
||||
const Expr* inlined = inline_mapping_[v];
|
||||
for (auto* nv : VarFinder::find(inlined)) {
|
||||
pair.second.insert(nv);
|
||||
}
|
||||
}
|
||||
}
|
||||
inline_mapping_.erase(v);
|
||||
}
|
||||
return result;
|
||||
return mutate_loads(buf, v->indices());
|
||||
}
|
||||
|
||||
// Replace the target variable with the caller expressions.
|
||||
|
|
|
|||
Loading…
Reference in a new issue