From 91da2d5fa11c1a416420133038fcdc49a0eceb68 Mon Sep 17 00:00:00 2001 From: Raghavan Raman Date: Tue, 21 Dec 2021 11:18:50 -0800 Subject: [PATCH] [StaticRuntime] Refactor StaticModule to pass in sample inputs (#69473) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69473 This diff refactors StaticModule and its uses to pass in sample inputs. These inputs need to be passed into the constructor because they are need to perform TensorExpr fusion before other optimizations are performed on the input graph. ghstack-source-id: 146059041 Test Plan: buck run mode/opt //caffe2/caffe2/fb/predictor:pytorch_predictor_test Reviewed By: donaldong Differential Revision: D32320084 fbshipit-source-id: b8bd46d442be4cc90ca60f521e0416fdb88eea60 --- benchmarks/static_runtime/test_utils.cc | 5 +++-- torch/csrc/jit/runtime/static/impl.cpp | 27 ++++++++++++++++--------- torch/csrc/jit/runtime/static/impl.h | 6 ++++-- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/benchmarks/static_runtime/test_utils.cc b/benchmarks/static_runtime/test_utils.cc index 31229913433..1f5ecaa77ed 100644 --- a/benchmarks/static_runtime/test_utils.cc +++ b/benchmarks/static_runtime/test_utils.cc @@ -69,7 +69,8 @@ class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext { } StaticModule makeStaticModule(const StaticModuleOptions& opt) const override { - return torch::jit::StaticModule(module_, /* is_frozen */ false, opt); + return torch::jit::StaticModule( + module_, /* is_frozen */ false, opt, /* sample_inputs */ {}); } private: @@ -91,7 +92,7 @@ class GraphStaticRuntimeContext : public StaticRuntimeTestContext { } StaticModule makeStaticModule(const StaticModuleOptions& opt) const override { - return StaticModule(graph_, opt); + return StaticModule(graph_, opt, /* sample_inputs */ {}); } private: diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index aa98c71a66a..41a1eeaea41 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -170,7 +170,8 @@ bool mayContainAlias( void PrepareGraphForStaticModule( std::shared_ptr graph, - const StaticModuleOptions& opts) { + const StaticModuleOptions& opts, + std::vector sample_inputs) { TORCH_CHECK(canEnableStaticRuntime(graph)); OptimizeGraph(graph, opts); } @@ -178,7 +179,8 @@ void PrepareGraphForStaticModule( std::pair, c10::optional> PrepareForStaticModule( const torch::jit::Module& m, bool is_frozen, - const StaticModuleOptions& opts) { + const StaticModuleOptions& opts, + std::vector sample_inputs) { LOG(INFO) << "StaticModuleOptions: cleanup_activations " << opts.cleanup_activations << ", enable_out_variant " << opts.enable_out_variant << ", optimize_memory " @@ -194,15 +196,16 @@ std::pair, c10::optional> PrepareForStaticModule( Method method = module.get_method("forward"); auto graph = module.get_method("forward").graph(); - PrepareGraphForStaticModule(graph, opts); + PrepareGraphForStaticModule(graph, opts, sample_inputs); return std::make_pair(graph, module); } std::pair, c10::optional> PrepareForStaticModule( std::shared_ptr graph, - const StaticModuleOptions& opts) { - PrepareGraphForStaticModule(graph, opts); + const StaticModuleOptions& opts, + std::vector sample_inputs) { + PrepareGraphForStaticModule(graph, opts, sample_inputs); return std::make_pair(graph, c10::nullopt); } @@ -423,14 +426,20 @@ std::vector ManagedTensorRanges:: StaticModule::StaticModule( std::shared_ptr g, - const StaticModuleOptions& opts) - : StaticModule(PrepareForStaticModule(g->copy(), opts), opts) {} + const StaticModuleOptions& opts, + std::vector sample_inputs) + : StaticModule( + PrepareForStaticModule(g->copy(), opts, sample_inputs), + opts) {} StaticModule::StaticModule( const torch::jit::Module& m, bool is_frozen, - const StaticModuleOptions& opts) - : StaticModule(PrepareForStaticModule(m, is_frozen, opts), opts) {} + const StaticModuleOptions& opts, + std::vector sample_inputs) + : StaticModule( + PrepareForStaticModule(m, is_frozen, opts, sample_inputs), + opts) {} StaticModule::StaticModule( std::pair, c10::optional> diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 1dd7fd01705..cb4f9072b45 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -218,12 +218,14 @@ class TORCH_API StaticModule { public: explicit StaticModule( std::shared_ptr g, - const StaticModuleOptions& opts = StaticModuleOptions()); + const StaticModuleOptions& opts = StaticModuleOptions(), + std::vector sample_inputs = {}); explicit StaticModule( const torch::jit::Module& m, bool is_frozen = false, - const StaticModuleOptions& opts = StaticModuleOptions()); + const StaticModuleOptions& opts = StaticModuleOptions(), + std::vector sample_inputs = {}); typedef enum { CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant