pytorch/test/cpp/jit/gtest.cpp
Richard Zou b14d6d730a Reuse KernelSpec for FusionGroups with equivalent graphs (#14541)
Summary:
Before this PR, loop unrolling combined with the graph fuser created
multiple FusionGroups with identical bodies (differing only in variable
names) for JIT LSTMs. Each FusionGroup was registered under a separate
fusion key, and each key triggered its own compilation of the same
specializations.

With this PR, when a FusionGroup is registered with the fusion compiler,
the compiler first checks the KernelSpec cache to see whether the
FusionGroup's graph already exists. If it does, the compiler returns the
corresponding KernelSpec's key so that the FusionGroups share compiled
kernels.
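The lookup described above can be sketched as a map from a graph's canonical form to its fusion key. This is a minimal sketch under that assumption, not the fusion compiler's actual code: `KernelSpecCache` and `registerFusion` are hypothetical names, and a graph is modeled as its canonical text rather than as a real `torch::jit::Graph`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the KernelSpec cache. A graph is modeled as
// its canonical textual form (an assumption made for illustration).
class KernelSpecCache {
 public:
  // Returns the fusion key for `canonicalGraph`, reusing the existing key
  // when an equivalent graph has already been registered.
  std::int64_t registerFusion(const std::string& canonicalGraph) {
    auto it = graphToKey_.find(canonicalGraph);
    if (it != graphToKey_.end()) {
      return it->second;  // equivalent graph: share its compiled kernels
    }
    const std::int64_t key = nextKey_++;
    graphToKey_.emplace(canonicalGraph, key);
    return key;
  }

  // Number of distinct graphs (and therefore distinct keys) registered.
  std::size_t size() const { return graphToKey_.size(); }

 private:
  std::unordered_map<std::string, std::int64_t> graphToKey_;
  std::int64_t nextKey_ = 0;
};
```

Registering the same canonical graph twice yields the same key, so only one set of kernels would be compiled for it. Keying on text keeps the sketch simple; lookup cost grows with the size of the graph's text, which matters little if registration happens once per FusionGroup.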

In addition, graphs in the KernelSpec cache are canonicalized before
being cached. I added a flag to the canonicalize pass that removes the
unique names of values, so graphs that differ only in variable names
canonicalize to the same form.
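Removing unique names can be illustrated with a toy IR: renaming every value positionally (%0, %1, ...) in order of first appearance makes two graphs that differ only in variable names serialize identically. `Node` and `canonicalize` below are hypothetical; this is a sketch of the idea, not the JIT's actual canonicalize pass.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy IR node: an op kind, input value names, one output name.
struct Node {
  std::string kind;
  std::vector<std::string> inputs;
  std::string output;
};

// Renames every value to %0, %1, ... in order of first appearance,
// discarding the original (unique) names, and serializes the result.
std::string canonicalize(const std::vector<std::string>& graphInputs,
                         const std::vector<Node>& nodes) {
  std::unordered_map<std::string, std::string> renamed;
  auto nameOf = [&renamed](const std::string& value) {
    auto it = renamed.find(value);
    if (it == renamed.end()) {
      // The new name uses the map size before this insertion.
      it = renamed.emplace(value, "%" + std::to_string(renamed.size())).first;
    }
    return it->second;
  };

  std::string out = "graph(";
  for (std::size_t i = 0; i < graphInputs.size(); ++i) {
    out += (i ? ", " : "") + nameOf(graphInputs[i]);
  }
  out += ")\n";
  for (const Node& n : nodes) {
    out += "  " + nameOf(n.output) + " = " + n.kind + "(";
    for (std::size_t i = 0; i < n.inputs.size(); ++i) {
      out += (i ? ", " : "") + nameOf(n.inputs[i]);
    }
    out += ")\n";
  }
  return out;
}
```

For example, a graph over `x.1, y.1` and an otherwise identical one over `x.7, y.7` both serialize to `graph(%0, %1)` followed by `%2 = aten::mul(%0, %1)`, so they would map to the same cache entry.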

This shortens the compile time for a JIT LSTM (seq_len of 100, loop
unroll factor of 8) from 5.3s to 2.3s. Most of the compile time is spent
running the graph fuser and/or the fusion compiler. While this PR leaves
only one unique kernel in the forward pass, the backward pass (after
loop unrolling) still contains six distinct kernels, which should be
investigated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14541

Differential Revision: D13324487

Pulled By: zou3519

fbshipit-source-id: b841d82ed35a959b5cfc72db033bf5a7b42cc4fb
2018-12-13 07:54:35 -08:00

#include <gtest/gtest.h>
#include <test/cpp/jit/tests.h>
using namespace torch;
using namespace torch::jit;
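
// Wraps each test<name>() function provided by test/cpp/jit/tests.h
// in a gtest case named JitTest.<name>.
#define JIT_TEST(name)  \
  TEST(JitTest, name) { \
    test##name();       \
  }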
JIT_TEST(ADFormulas)
JIT_TEST(Attributes)
JIT_TEST(Blocks)
JIT_TEST(CodeTemplate)
JIT_TEST(ControlFlow)
JIT_TEST(CreateAutodiffSubgraphs)
JIT_TEST(CustomOperators)
JIT_TEST(Differentiate)
JIT_TEST(DifferentiateWithRequiresGrad)
JIT_TEST(DynamicDAG)
JIT_TEST(FromQualString)
JIT_TEST(InternedStrings)
JIT_TEST(IValue)
JIT_TEST(RegisterFusionCachesKernel)
JIT_TEST(SchemaParser)
JIT_TEST(TopologicalIndex)
JIT_TEST(TopologicalMove)
JIT_TEST(SubgraphUtils)
JIT_TEST(AliasAnalysis)
JIT_TEST(THNNConv)
JIT_TEST(ATenNativeBatchNorm)
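
// Same as JIT_TEST, but appends a _CUDA suffix to the test name to mark
// tests that require a GPU, so they can be filtered out by name.
#define JIT_TEST_CUDA(name)     \
  TEST(JitTest, name##_CUDA) {  \
    test##name();               \
  }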
JIT_TEST_CUDA(ArgumentSpec)
JIT_TEST_CUDA(Fusion)
JIT_TEST_CUDA(GraphExecutor)
JIT_TEST_CUDA(Interp)