diff --git a/benchmarks/static_runtime/test_utils.cc b/benchmarks/static_runtime/test_utils.cc index 31229913433..1f5ecaa77ed 100644 --- a/benchmarks/static_runtime/test_utils.cc +++ b/benchmarks/static_runtime/test_utils.cc @@ -69,7 +69,8 @@ class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext { } StaticModule makeStaticModule(const StaticModuleOptions& opt) const override { - return torch::jit::StaticModule(module_, /* is_frozen */ false, opt); + return torch::jit::StaticModule( + module_, /* is_frozen */ false, opt, /* sample_inputs */ {}); } private: @@ -91,7 +92,7 @@ class GraphStaticRuntimeContext : public StaticRuntimeTestContext { } StaticModule makeStaticModule(const StaticModuleOptions& opt) const override { - return StaticModule(graph_, opt); + return StaticModule(graph_, opt, /* sample_inputs */ {}); } private: diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index aa98c71a66a..41a1eeaea41 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -170,7 +170,8 @@ bool mayContainAlias( void PrepareGraphForStaticModule( std::shared_ptr graph, - const StaticModuleOptions& opts) { + const StaticModuleOptions& opts, + std::vector sample_inputs) { TORCH_CHECK(canEnableStaticRuntime(graph)); OptimizeGraph(graph, opts); } @@ -178,7 +179,8 @@ void PrepareGraphForStaticModule( std::pair, c10::optional> PrepareForStaticModule( const torch::jit::Module& m, bool is_frozen, - const StaticModuleOptions& opts) { + const StaticModuleOptions& opts, + std::vector sample_inputs) { LOG(INFO) << "StaticModuleOptions: cleanup_activations " << opts.cleanup_activations << ", enable_out_variant " << opts.enable_out_variant << ", optimize_memory " @@ -194,15 +196,16 @@ std::pair, c10::optional> PrepareForStaticModule( Method method = module.get_method("forward"); auto graph = module.get_method("forward").graph(); - PrepareGraphForStaticModule(graph, opts); + PrepareGraphForStaticModule(graph, opts, sample_inputs); return std::make_pair(graph, module); } std::pair, c10::optional> PrepareForStaticModule( std::shared_ptr graph, - const StaticModuleOptions& opts) { - PrepareGraphForStaticModule(graph, opts); + const StaticModuleOptions& opts, + std::vector sample_inputs) { + PrepareGraphForStaticModule(graph, opts, sample_inputs); return std::make_pair(graph, c10::nullopt); } @@ -423,14 +426,20 @@ std::vector ManagedTensorRanges:: StaticModule::StaticModule( std::shared_ptr g, - const StaticModuleOptions& opts) - : StaticModule(PrepareForStaticModule(g->copy(), opts), opts) {} + const StaticModuleOptions& opts, + std::vector sample_inputs) + : StaticModule( + PrepareForStaticModule(g->copy(), opts, sample_inputs), + opts) {} StaticModule::StaticModule( const torch::jit::Module& m, bool is_frozen, - const StaticModuleOptions& opts) - : StaticModule(PrepareForStaticModule(m, is_frozen, opts), opts) {} + const StaticModuleOptions& opts, + std::vector sample_inputs) + : StaticModule( + PrepareForStaticModule(m, is_frozen, opts, sample_inputs), + opts) {} StaticModule::StaticModule( std::pair, c10::optional> diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 1dd7fd01705..cb4f9072b45 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -218,12 +218,14 @@ class TORCH_API StaticModule { public: explicit StaticModule( std::shared_ptr g, - const StaticModuleOptions& opts = StaticModuleOptions()); + const StaticModuleOptions& opts = StaticModuleOptions(), + std::vector sample_inputs = {}); explicit StaticModule( const torch::jit::Module& m, bool is_frozen = false, - const StaticModuleOptions& opts = StaticModuleOptions()); + const StaticModuleOptions& opts = StaticModuleOptions(), + std::vector sample_inputs = {}); typedef enum { CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant