mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[SR] Quick hack to eliminate no-op slice (#75774)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75774 `list[0:]` is a no-op. This should really be eliminated on the modeling side, implement as a graph pass for now until we can get this into prod models. Test Plan: New unit tests Reviewed By: navahgar Differential Revision: D35632947 fbshipit-source-id: 0c564193c35039130e99172e0185e124ea24f62d (cherry picked from commit e01d5273185e39a563c7acb15662d9c1549d4b58)
This commit is contained in:
parent
4537ac11db
commit
b02b3f25db
4 changed files with 65 additions and 0 deletions
|
|
@ -1608,3 +1608,28 @@ TEST(UseSplitAndSqueeze, Fusion) {
|
|||
EXPECT_FALSE(hasNodeWithKind(graph, "aten::squeeze"));
|
||||
EXPECT_FALSE(hasNodeWithKind(graph, "prim::ListUnpack"));
|
||||
}
|
||||
|
||||
TEST(EliminateNoOpSlice, IntegerStart) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x: List[int]) -> List[int]:
|
||||
return x[0:]
|
||||
)JIT";
|
||||
torch::jit::Module mod("m");
|
||||
mod.define(src);
|
||||
auto graph = mod.get_method("forward").graph();
|
||||
EXPECT_TRUE(hasNodeWithKind(graph, "aten::slice"));
|
||||
EliminateNoOpSlice(graph);
|
||||
EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
|
||||
}
|
||||
|
||||
TEST(EliminateNoOpSlice, NoneStart) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, x: List[int]) -> List[int]:
|
||||
return x[:]
|
||||
)JIT";
|
||||
torch::jit::Module mod("m");
|
||||
mod.define(src);
|
||||
auto graph = mod.get_method("forward").graph();
|
||||
EliminateNoOpSlice(graph);
|
||||
EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -140,6 +140,7 @@ void OptimizeGraph(
|
|||
ConstantPropagation(graph);
|
||||
RemoveTensorMutation(graph);
|
||||
ConstantPropagation(graph);
|
||||
EliminateNoOpSlice(graph);
|
||||
EliminateDeadCode(graph);
|
||||
FuseInferenceOpsForSparseNN(graph);
|
||||
UseVariadicCat(graph);
|
||||
|
|
|
|||
|
|
@ -1192,5 +1192,42 @@ C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs(
|
|||
fuse.runOnGraph(graph);
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool isNoOpSlice(Node* node) {
|
||||
DCHECK(node->kind() == aten::slice);
|
||||
auto step = toIValue(node->input(3));
|
||||
if (!step.has_value() || step->toInt() != 1) {
|
||||
return false;
|
||||
}
|
||||
auto start = toIValue(node->input(1));
|
||||
if (!start.has_value() || (start->isInt() && start->toInt() != 0)) {
|
||||
return false;
|
||||
}
|
||||
auto end = toIValue(node->input(2));
|
||||
// Could also look at list length, but most models that have this pattern are
|
||||
// just doing list[0:], so it's not needed for now.
|
||||
return end.has_value() && end->isNone();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void EliminateNoOpSlice(std::shared_ptr<Graph>& graph) {
|
||||
DepthFirstGraphNodeIterator it(graph);
|
||||
auto schema = torch::schema(
|
||||
"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]");
|
||||
Node* node = nullptr;
|
||||
std::vector<Node*> to_delete;
|
||||
while ((node = it.next()) != nullptr) {
|
||||
if (!node->matches(schema) || !isNoOpSlice(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
node->output()->replaceAllUsesWith(node->input(0));
|
||||
to_delete.push_back(node);
|
||||
}
|
||||
for (auto* node : to_delete) {
|
||||
node->destroy();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -63,6 +63,8 @@ TORCH_API void UseVariadicGroupedAccessor(const std::shared_ptr<Graph>& graph);
|
|||
|
||||
TORCH_API void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph);
|
||||
|
||||
TORCH_API void EliminateNoOpSlice(std::shared_ptr<Graph>& graph);
|
||||
|
||||
TORCH_API void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph);
|
||||
|
||||
// [Remove unnecessary outputs]]
|
||||
|
|
|
|||
Loading…
Reference in a new issue