[JIT] Disable Complete Shape Inlining For Testing Purposes (#56966)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56966

This PR adds a toggle to shape analysis which won't inline complete tensor shapes as constants into the shape compute graph, which is a good stress test on the partial evaluation pipeline.

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D28444664

Pulled By: eellison

fbshipit-source-id: a62e424515a8837a4b596546efa93af5e8e61f10
This commit is contained in:
eellison 2021-05-27 17:52:46 -07:00 committed by Facebook GitHub Bot
parent f66fbb1e2e
commit d8cbba3ee2
5 changed files with 42 additions and 1 deletions

View file

@ -13,6 +13,13 @@ if __name__ == '__main__':
# XXX: still in prototype
class TestSymbolicShapeAnalysis(JitTestCase):
def setUp(self):
self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
torch._C._jit_set_symbolic_shapes_test_mode(True)
def tearDown(self):
torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled)
def test_shape_analysis(self):
@torch.jit.script
def foo(x, y):

View file

@ -213,6 +213,8 @@ def _jit_nvfuser_enabled() -> _bool: ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...

View file

@ -36,9 +36,21 @@ pointwise ops)
- Supporting returning partially evaluated shape compute graph
*/
static bool symbolic_shape_analysis_test_mode = false;
namespace torch {
namespace jit {
bool setSymbolicShapeAnalysisTestMode(bool value) {
bool old_value = symbolic_shape_analysis_test_mode;
symbolic_shape_analysis_test_mode = value;
return old_value;
}
bool symbolicShapeAnalysisTestModeEnabled() {
return symbolic_shape_analysis_test_mode;
}
// TODO: better registration mechanism
std::mutex lock;
std::unordered_map<std::string, std::shared_ptr<Graph>> operator_functions;
@ -79,7 +91,14 @@ struct SymbolicShapeAnalyzer {
auto type = node_->input(i)->type();
if (auto tt = type->castRaw<TensorType>()) {
c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();
if (symbolic_shapes.isComplete()) {
// for testing, we don't insert complete tensor shapes and rely on our
// partial evaluation pipeline to propagate information.
// this is a good proxy for our ability to propagate non-complete shape
// information.
if (symbolic_shapes.isComplete() &&
!symbolic_shape_analysis_test_mode) {
replaceWithIValue(
graph_->inputs().at(i), *tt->sizes().concrete_sizes());
continue;

View file

@ -10,5 +10,12 @@ namespace jit {
TORCH_API void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph);
// don't insert complete tensor shapes in shape compute graphs and instead
// rely on our partial evaluation pipeline to propagate information.
// this is a good proxy for our ability to propagate non-complete shape
// information.
TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
TORCH_API bool symbolicShapeAnalysisTestModeEnabled();
} // namespace jit
} // namespace torch

View file

@ -179,6 +179,12 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
.def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
.def("_jit_pass_integer_value_refinement", RefineIntegerValues)
.def(
"_jit_set_symbolic_shapes_test_mode",
&setSymbolicShapeAnalysisTestMode)
.def(
"_jit_symbolic_shapes_test_mode_enabled",
&symbolicShapeAnalysisTestModeEnabled)
.def(
"_jit_pass_onnx_fold_if",
[](std::shared_ptr<Graph>& graph) {