mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[StaticRuntime] Refactor StaticModule to pass in sample inputs (#69473)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69473 This diff refactors StaticModule and its uses to pass in sample inputs. These inputs need to be passed into the constructor because they are need to perform TensorExpr fusion before other optimizations are performed on the input graph. ghstack-source-id: 146059041 Test Plan: buck run mode/opt //caffe2/caffe2/fb/predictor:pytorch_predictor_test Reviewed By: donaldong Differential Revision: D32320084 fbshipit-source-id: b8bd46d442be4cc90ca60f521e0416fdb88eea60
This commit is contained in:
parent
c4a6c7a436
commit
91da2d5fa1
3 changed files with 25 additions and 13 deletions
|
|
@ -69,7 +69,8 @@ class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
|
|||
}
|
||||
|
||||
StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
|
||||
return torch::jit::StaticModule(module_, /* is_frozen */ false, opt);
|
||||
return torch::jit::StaticModule(
|
||||
module_, /* is_frozen */ false, opt, /* sample_inputs */ {});
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -91,7 +92,7 @@ class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
|
|||
}
|
||||
|
||||
StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
|
||||
return StaticModule(graph_, opt);
|
||||
return StaticModule(graph_, opt, /* sample_inputs */ {});
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -170,7 +170,8 @@ bool mayContainAlias(
|
|||
|
||||
void PrepareGraphForStaticModule(
|
||||
std::shared_ptr<torch::jit::Graph> graph,
|
||||
const StaticModuleOptions& opts) {
|
||||
const StaticModuleOptions& opts,
|
||||
std::vector<IValue> sample_inputs) {
|
||||
TORCH_CHECK(canEnableStaticRuntime(graph));
|
||||
OptimizeGraph(graph, opts);
|
||||
}
|
||||
|
|
@ -178,7 +179,8 @@ void PrepareGraphForStaticModule(
|
|||
std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
|
||||
const torch::jit::Module& m,
|
||||
bool is_frozen,
|
||||
const StaticModuleOptions& opts) {
|
||||
const StaticModuleOptions& opts,
|
||||
std::vector<IValue> sample_inputs) {
|
||||
LOG(INFO) << "StaticModuleOptions: cleanup_activations "
|
||||
<< opts.cleanup_activations << ", enable_out_variant "
|
||||
<< opts.enable_out_variant << ", optimize_memory "
|
||||
|
|
@ -194,15 +196,16 @@ std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
|
|||
Method method = module.get_method("forward");
|
||||
auto graph = module.get_method("forward").graph();
|
||||
|
||||
PrepareGraphForStaticModule(graph, opts);
|
||||
PrepareGraphForStaticModule(graph, opts, sample_inputs);
|
||||
|
||||
return std::make_pair(graph, module);
|
||||
}
|
||||
|
||||
std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
|
||||
std::shared_ptr<torch::jit::Graph> graph,
|
||||
const StaticModuleOptions& opts) {
|
||||
PrepareGraphForStaticModule(graph, opts);
|
||||
const StaticModuleOptions& opts,
|
||||
std::vector<IValue> sample_inputs) {
|
||||
PrepareGraphForStaticModule(graph, opts, sample_inputs);
|
||||
return std::make_pair(graph, c10::nullopt);
|
||||
}
|
||||
|
||||
|
|
@ -423,14 +426,20 @@ std::vector<const Value*> ManagedTensorRanges::
|
|||
|
||||
StaticModule::StaticModule(
|
||||
std::shared_ptr<torch::jit::Graph> g,
|
||||
const StaticModuleOptions& opts)
|
||||
: StaticModule(PrepareForStaticModule(g->copy(), opts), opts) {}
|
||||
const StaticModuleOptions& opts,
|
||||
std::vector<IValue> sample_inputs)
|
||||
: StaticModule(
|
||||
PrepareForStaticModule(g->copy(), opts, sample_inputs),
|
||||
opts) {}
|
||||
|
||||
StaticModule::StaticModule(
|
||||
const torch::jit::Module& m,
|
||||
bool is_frozen,
|
||||
const StaticModuleOptions& opts)
|
||||
: StaticModule(PrepareForStaticModule(m, is_frozen, opts), opts) {}
|
||||
const StaticModuleOptions& opts,
|
||||
std::vector<IValue> sample_inputs)
|
||||
: StaticModule(
|
||||
PrepareForStaticModule(m, is_frozen, opts, sample_inputs),
|
||||
opts) {}
|
||||
|
||||
StaticModule::StaticModule(
|
||||
std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<Module>>
|
||||
|
|
|
|||
|
|
@ -218,12 +218,14 @@ class TORCH_API StaticModule {
|
|||
public:
|
||||
explicit StaticModule(
|
||||
std::shared_ptr<torch::jit::Graph> g,
|
||||
const StaticModuleOptions& opts = StaticModuleOptions());
|
||||
const StaticModuleOptions& opts = StaticModuleOptions(),
|
||||
std::vector<IValue> sample_inputs = {});
|
||||
|
||||
explicit StaticModule(
|
||||
const torch::jit::Module& m,
|
||||
bool is_frozen = false,
|
||||
const StaticModuleOptions& opts = StaticModuleOptions());
|
||||
const StaticModuleOptions& opts = StaticModuleOptions(),
|
||||
std::vector<IValue> sample_inputs = {});
|
||||
|
||||
typedef enum {
|
||||
CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
|
||||
|
|
|
|||
Loading…
Reference in a new issue