mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[JIT] Disable Complete Shape Inlining For Testing Purposes (#56966)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56966 This PR adds a toggle to shape analysis which won't inline complete tensor shapes as constants into the shape compute graph, which is a good stress test on the partial evaluation pipeline. Test Plan: Imported from OSS Reviewed By: bdhirsh Differential Revision: D28444664 Pulled By: eellison fbshipit-source-id: a62e424515a8837a4b596546efa93af5e8e61f10
This commit is contained in:
parent
f66fbb1e2e
commit
d8cbba3ee2
5 changed files with 42 additions and 1 deletions
|
|
@ -13,6 +13,13 @@ if __name__ == '__main__':
|
|||
|
||||
# XXX: still in prototype
|
||||
class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
def setUp(self):
|
||||
self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
|
||||
torch._C._jit_set_symbolic_shapes_test_mode(True)
|
||||
|
||||
def tearDown(self):
|
||||
torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled)
|
||||
|
||||
def test_shape_analysis(self):
|
||||
@torch.jit.script
|
||||
def foo(x, y):
|
||||
|
|
|
|||
|
|
@ -213,6 +213,8 @@ def _jit_nvfuser_enabled() -> _bool: ...
|
|||
def _llvm_enabled() -> _bool: ...
|
||||
def _jit_override_can_fuse_on_cpu(override: _bool): ...
|
||||
def _jit_override_can_fuse_on_gpu(override: _bool): ...
|
||||
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
|
||||
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
|
||||
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
|
||||
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
|
||||
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...
|
||||
|
|
|
|||
|
|
@ -36,9 +36,21 @@ pointwise ops)
|
|||
- Supporting returning partially evaluated shape compute graph
|
||||
*/
|
||||
|
||||
static bool symbolic_shape_analysis_test_mode = false;
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
bool setSymbolicShapeAnalysisTestMode(bool value) {
|
||||
bool old_value = symbolic_shape_analysis_test_mode;
|
||||
symbolic_shape_analysis_test_mode = value;
|
||||
return old_value;
|
||||
}
|
||||
|
||||
bool symbolicShapeAnalysisTestModeEnabled() {
|
||||
return symbolic_shape_analysis_test_mode;
|
||||
}
|
||||
|
||||
// TODO: better registration mechanism
|
||||
std::mutex lock;
|
||||
std::unordered_map<std::string, std::shared_ptr<Graph>> operator_functions;
|
||||
|
|
@ -79,7 +91,14 @@ struct SymbolicShapeAnalyzer {
|
|||
auto type = node_->input(i)->type();
|
||||
if (auto tt = type->castRaw<TensorType>()) {
|
||||
c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();
|
||||
if (symbolic_shapes.isComplete()) {
|
||||
|
||||
// for testing, we don't insert complete tensor shapes and rely on our
|
||||
// partial evaluation pipeline to propagate information.
|
||||
// this is a good proxy for our ability to propagate non-complete shape
|
||||
// information.
|
||||
|
||||
if (symbolic_shapes.isComplete() &&
|
||||
!symbolic_shape_analysis_test_mode) {
|
||||
replaceWithIValue(
|
||||
graph_->inputs().at(i), *tt->sizes().concrete_sizes());
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -10,5 +10,12 @@ namespace jit {
|
|||
|
||||
TORCH_API void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph);
|
||||
|
||||
// don't insert complete tensor shapes in shape compute graphs and instead
|
||||
// rely on our partial evaluation pipeline to propagate information.
|
||||
// this is a good proxy for our ability to propagate non-complete shape
|
||||
// information.
|
||||
TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
|
||||
TORCH_API bool symbolicShapeAnalysisTestModeEnabled();
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -179,6 +179,12 @@ void initJITBindings(PyObject* module) {
|
|||
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
|
||||
.def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
|
||||
.def("_jit_pass_integer_value_refinement", RefineIntegerValues)
|
||||
.def(
|
||||
"_jit_set_symbolic_shapes_test_mode",
|
||||
&setSymbolicShapeAnalysisTestMode)
|
||||
.def(
|
||||
"_jit_symbolic_shapes_test_mode_enabled",
|
||||
&symbolicShapeAnalysisTestModeEnabled)
|
||||
.def(
|
||||
"_jit_pass_onnx_fold_if",
|
||||
[](std::shared_ptr<Graph>& graph) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue