diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc index 793e08d4354..cc2c2e1abd5 100644 --- a/benchmarks/static_runtime/test_static_module.cc +++ b/benchmarks/static_runtime/test_static_module.cc @@ -1608,3 +1608,28 @@ TEST(UseSplitAndSqueeze, Fusion) { EXPECT_FALSE(hasNodeWithKind(graph, "aten::squeeze")); EXPECT_FALSE(hasNodeWithKind(graph, "prim::ListUnpack")); } + +TEST(EliminateNoOpSlice, IntegerStart) { + const auto src = R"JIT( + def forward(self, x: List[int]) -> List[int]: + return x[0:] + )JIT"; + torch::jit::Module mod("m"); + mod.define(src); + auto graph = mod.get_method("forward").graph(); + EXPECT_TRUE(hasNodeWithKind(graph, "aten::slice")); + EliminateNoOpSlice(graph); + EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice")); +} + +TEST(EliminateNoOpSlice, NoneStart) { + const auto src = R"JIT( + def forward(self, x: List[int]) -> List[int]: + return x[:] + )JIT"; + torch::jit::Module mod("m"); + mod.define(src); + auto graph = mod.get_method("forward").graph(); + EliminateNoOpSlice(graph); + EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice")); +} diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index ab1506a9631..57ecd0795e8 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -140,6 +140,7 @@ void OptimizeGraph( ConstantPropagation(graph); RemoveTensorMutation(graph); ConstantPropagation(graph); + EliminateNoOpSlice(graph); EliminateDeadCode(graph); FuseInferenceOpsForSparseNN(graph); UseVariadicCat(graph); diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 0a8eb920d2f..8de97cdd1f0 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -1192,5 +1192,42 @@ C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs( fuse.runOnGraph(graph); } +namespace { +bool isNoOpSlice(Node* node) { + DCHECK(node->kind() == aten::slice); + auto step = toIValue(node->input(3)); + if (!step.has_value() || step->toInt() != 1) { + return false; + } + auto start = toIValue(node->input(1)); + if (!start.has_value() || (start->isInt() && start->toInt() != 0)) { + return false; + } + auto end = toIValue(node->input(2)); + // Could also look at list length, but most models that have this pattern are + // just doing list[0:], so it's not needed for now. + return end.has_value() && end->isNone(); +} +} // namespace + +void EliminateNoOpSlice(std::shared_ptr& graph) { + DepthFirstGraphNodeIterator it(graph); + auto schema = torch::schema( + "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]"); + Node* node = nullptr; + std::vector to_delete; + while ((node = it.next()) != nullptr) { + if (!node->matches(schema) || !isNoOpSlice(node)) { + continue; + } + + node->output()->replaceAllUsesWith(node->input(0)); + to_delete.push_back(node); + } + for (auto* node : to_delete) { + node->destroy(); + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h index e1239b7db29..6672d004bd1 100644 --- a/torch/csrc/jit/runtime/static/passes.h +++ b/torch/csrc/jit/runtime/static/passes.h @@ -63,6 +63,8 @@ TORCH_API void UseVariadicGroupedAccessor(const std::shared_ptr& graph); TORCH_API void EliminateExtraPermuteOps(std::shared_ptr& graph); +TORCH_API void EliminateNoOpSlice(std::shared_ptr& graph); + TORCH_API void UseSplitAndSqueeze(std::shared_ptr& graph); // [Remove unnecessary outputs]]