diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py index d65a91a420e..5205845f315 100644 --- a/test/jit/test_symbolic_shape_analysis.py +++ b/test/jit/test_symbolic_shape_analysis.py @@ -13,6 +13,13 @@ if __name__ == '__main__': # XXX: still in prototype class TestSymbolicShapeAnalysis(JitTestCase): + def setUp(self): + self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled() + torch._C._jit_set_symbolic_shapes_test_mode(True) + + def tearDown(self): + torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled) + def test_shape_analysis(self): @torch.jit.script def foo(x, y): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index fe191bec9eb..1600a424d1c 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -213,6 +213,8 @@ def _jit_nvfuser_enabled() -> _bool: ... def _llvm_enabled() -> _bool: ... def _jit_override_can_fuse_on_cpu(override: _bool): ... def _jit_override_can_fuse_on_gpu(override: _bool): ... +def _jit_set_symbolic_shapes_test_mode(override: _bool): ... +def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ... def _jit_set_texpr_fuser_enabled(enable: _bool): ... def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ... def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ... diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp index e1683d65c12..cfbcf086949 100644 --- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp @@ -36,9 +36,21 @@ pointwise ops) - Supporting returning partially evaluated shape compute graph */ +static bool symbolic_shape_analysis_test_mode = false; + namespace torch { namespace jit { +bool setSymbolicShapeAnalysisTestMode(bool value) { + bool old_value = symbolic_shape_analysis_test_mode; + symbolic_shape_analysis_test_mode = value; + return old_value; +} + +bool symbolicShapeAnalysisTestModeEnabled() { + return symbolic_shape_analysis_test_mode; +} + // TODO: better registration mechanism std::mutex lock; std::unordered_map> operator_functions; @@ -79,7 +91,14 @@ struct SymbolicShapeAnalyzer { auto type = node_->input(i)->type(); if (auto tt = type->castRaw()) { c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes(); - if (symbolic_shapes.isComplete()) { + + // for testing, we don't insert complete tensor shapes and rely on our + // partial evaluation pipeline to propagate information. + // this is a good proxy for our ability to propagate non-complete shape + // information. + + if (symbolic_shapes.isComplete() && + !symbolic_shape_analysis_test_mode) { replaceWithIValue( graph_->inputs().at(i), *tt->sizes().concrete_sizes()); continue; diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.h b/torch/csrc/jit/passes/symbolic_shape_analysis.h index 61798a1fb05..e1beefdf637 100644 --- a/torch/csrc/jit/passes/symbolic_shape_analysis.h +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.h @@ -10,5 +10,12 @@ namespace jit { TORCH_API void PropagateShapesOnGraph(std::shared_ptr& graph); +// don't insert complete tensor shapes in shape compute graphs and instead +// rely on our partial evaluation pipeline to propagate information. +// this is a good proxy for our ability to propagate non-complete shape +// information. +TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value); +TORCH_API bool symbolicShapeAnalysisTestModeEnabled(); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 6e98dc18f2b..05a452fcbfe 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -179,6 +179,12 @@ void initJITBindings(PyObject* module) { .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph) .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution) .def("_jit_pass_integer_value_refinement", RefineIntegerValues) + .def( + "_jit_set_symbolic_shapes_test_mode", + &setSymbolicShapeAnalysisTestMode) + .def( + "_jit_symbolic_shapes_test_mode_enabled", + &symbolicShapeAnalysisTestModeEnabled) .def( "_jit_pass_onnx_fold_if", [](std::shared_ptr& graph) {