[SR] Quick hack to eliminate no-op slice (#75774)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75774

`list[0:]` is a no-op. This should really be eliminated on the modeling side; it is implemented as a graph pass for now until we can get that change into prod models.

Test Plan: New unit tests

Reviewed By: navahgar

Differential Revision: D35632947

fbshipit-source-id: 0c564193c35039130e99172e0185e124ea24f62d
(cherry picked from commit e01d5273185e39a563c7acb15662d9c1549d4b58)
This commit is contained in:
Mike Iovine 2022-05-03 12:19:44 -07:00 committed by PyTorch MergeBot
parent 4537ac11db
commit b02b3f25db
4 changed files with 65 additions and 0 deletions

View file

@ -1608,3 +1608,28 @@ TEST(UseSplitAndSqueeze, Fusion) {
EXPECT_FALSE(hasNodeWithKind(graph, "aten::squeeze"));
EXPECT_FALSE(hasNodeWithKind(graph, "prim::ListUnpack"));
}
TEST(EliminateNoOpSlice, IntegerStart) {
  // x[0:] with an explicit integer start of 0 should be recognized as a
  // no-op slice and removed from the graph entirely.
  const auto source = R"JIT(
    def forward(self, x: List[int]) -> List[int]:
        return x[0:]
  )JIT";
  torch::jit::Module module("m");
  module.define(source);
  auto graph = module.get_method("forward").graph();
  // Sanity check: the slice node exists before the pass runs, so the
  // post-condition below is meaningful.
  EXPECT_TRUE(hasNodeWithKind(graph, "aten::slice"));
  EliminateNoOpSlice(graph);
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
}
TEST(EliminateNoOpSlice, NoneStart) {
  // x[:] desugars to aten::slice with start=None; this is also a no-op
  // and should be removed by the pass.
  const auto src = R"JIT(
    def forward(self, x: List[int]) -> List[int]:
        return x[:]
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  // Consistency with the IntegerStart test: assert the slice node is
  // present before the pass, so the EXPECT_FALSE below cannot pass
  // vacuously if the frontend ever stops emitting aten::slice here.
  EXPECT_TRUE(hasNodeWithKind(graph, "aten::slice"));
  EliminateNoOpSlice(graph);
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
}

View file

@ -140,6 +140,7 @@ void OptimizeGraph(
ConstantPropagation(graph);
RemoveTensorMutation(graph);
ConstantPropagation(graph);
EliminateNoOpSlice(graph);
EliminateDeadCode(graph);
FuseInferenceOpsForSparseNN(graph);
UseVariadicCat(graph);

View file

@ -1192,5 +1192,42 @@ C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs(
fuse.runOnGraph(graph);
}
namespace {
// Returns true iff `node` is an aten::slice that provably copies the whole
// list: step == 1, start == 0 (or None, which defaults to 0), end == None.
// Precondition: node->kind() == aten::slice (DCHECK'd below).
// Inputs follow aten::slice.t: (list, start, end, step). toIValue only
// yields a value for compile-time constants, so any non-constant argument
// makes the slice unprovable and we conservatively return false.
bool isNoOpSlice(Node* node) {
  DCHECK(node->kind() == aten::slice);
  auto step = toIValue(node->input(3));
  if (!step.has_value() || step->toInt() != 1) {
    return false;
  }
  auto start = toIValue(node->input(1));
  // A None start passes this check (isInt() is false); an integer start
  // must be exactly 0.
  if (!start.has_value() || (start->isInt() && start->toInt() != 0)) {
    return false;
  }
  auto end = toIValue(node->input(2));
  // Could also look at list length, but most models that have this pattern are
  // just doing list[0:], so it's not needed for now.
  return end.has_value() && end->isNone();
}
} // namespace
void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
auto schema = torch::schema(
"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]");
Node* node = nullptr;
std::vector<Node*> to_delete;
while ((node = it.next()) != nullptr) {
if (!node->matches(schema) || !isNoOpSlice(node)) {
continue;
}
node->output()->replaceAllUsesWith(node->input(0));
to_delete.push_back(node);
}
for (auto* node : to_delete) {
node->destroy();
}
}
} // namespace jit
} // namespace torch

View file

@ -63,6 +63,8 @@ TORCH_API void UseVariadicGroupedAccessor(const std::shared_ptr<Graph>& graph);
TORCH_API void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph);
TORCH_API void EliminateNoOpSlice(std::shared_ptr<Graph>& graph);
TORCH_API void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph);
// [Remove unnecessary outputs]