[StaticRuntime] Refactor StaticModule to pass in sample inputs (#69473)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69473

This diff refactors StaticModule and its uses to pass in sample inputs. These inputs need to be passed into the constructor because they are need to perform TensorExpr fusion before other optimizations are performed on the input graph.
ghstack-source-id: 146059041

Test Plan: buck run mode/opt //caffe2/caffe2/fb/predictor:pytorch_predictor_test

Reviewed By: donaldong

Differential Revision: D32320084

fbshipit-source-id: b8bd46d442be4cc90ca60f521e0416fdb88eea60
This commit is contained in:
Raghavan Raman 2021-12-21 11:18:50 -08:00 committed by Facebook GitHub Bot
parent c4a6c7a436
commit 91da2d5fa1
3 changed files with 25 additions and 13 deletions

View file

@ -69,7 +69,8 @@ class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
}
StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
return torch::jit::StaticModule(module_, /* is_frozen */ false, opt);
return torch::jit::StaticModule(
module_, /* is_frozen */ false, opt, /* sample_inputs */ {});
}
private:
@ -91,7 +92,7 @@ class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
}
StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
return StaticModule(graph_, opt);
return StaticModule(graph_, opt, /* sample_inputs */ {});
}
private:

View file

@ -170,7 +170,8 @@ bool mayContainAlias(
void PrepareGraphForStaticModule(
std::shared_ptr<torch::jit::Graph> graph,
const StaticModuleOptions& opts) {
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs) {
TORCH_CHECK(canEnableStaticRuntime(graph));
OptimizeGraph(graph, opts);
}
@ -178,7 +179,8 @@ void PrepareGraphForStaticModule(
std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
const torch::jit::Module& m,
bool is_frozen,
const StaticModuleOptions& opts) {
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs) {
LOG(INFO) << "StaticModuleOptions: cleanup_activations "
<< opts.cleanup_activations << ", enable_out_variant "
<< opts.enable_out_variant << ", optimize_memory "
@ -194,15 +196,16 @@ std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
Method method = module.get_method("forward");
auto graph = module.get_method("forward").graph();
PrepareGraphForStaticModule(graph, opts);
PrepareGraphForStaticModule(graph, opts, sample_inputs);
return std::make_pair(graph, module);
}
std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
std::shared_ptr<torch::jit::Graph> graph,
const StaticModuleOptions& opts) {
PrepareGraphForStaticModule(graph, opts);
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs) {
PrepareGraphForStaticModule(graph, opts, sample_inputs);
return std::make_pair(graph, c10::nullopt);
}
@ -423,14 +426,20 @@ std::vector<const Value*> ManagedTensorRanges::
StaticModule::StaticModule(
std::shared_ptr<torch::jit::Graph> g,
const StaticModuleOptions& opts)
: StaticModule(PrepareForStaticModule(g->copy(), opts), opts) {}
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs)
: StaticModule(
PrepareForStaticModule(g->copy(), opts, sample_inputs),
opts) {}
StaticModule::StaticModule(
const torch::jit::Module& m,
bool is_frozen,
const StaticModuleOptions& opts)
: StaticModule(PrepareForStaticModule(m, is_frozen, opts), opts) {}
const StaticModuleOptions& opts,
std::vector<IValue> sample_inputs)
: StaticModule(
PrepareForStaticModule(m, is_frozen, opts, sample_inputs),
opts) {}
StaticModule::StaticModule(
std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<Module>>

View file

@ -218,12 +218,14 @@ class TORCH_API StaticModule {
public:
explicit StaticModule(
std::shared_ptr<torch::jit::Graph> g,
const StaticModuleOptions& opts = StaticModuleOptions());
const StaticModuleOptions& opts = StaticModuleOptions(),
std::vector<IValue> sample_inputs = {});
explicit StaticModule(
const torch::jit::Module& m,
bool is_frozen = false,
const StaticModuleOptions& opts = StaticModuleOptions());
const StaticModuleOptions& opts = StaticModuleOptions(),
std::vector<IValue> sample_inputs = {});
typedef enum {
CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant