Enable prepacking in subgraph (#5433)

Prepacking in subgraph is not supported currently. We see more and more models with subgraph, which has MatMul, MatMulInteger and other ops. Prepacking can speed up those models significantly.
This commit is contained in:
Yufeng Li 2020-10-26 22:22:31 -07:00 committed by GitHub
parent 564da960ce
commit 30cdc74bc0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 218 additions and 61 deletions

View file

@ -121,7 +121,7 @@ class Node {
/** Gets the Node's Node::Type. */
Node::Type NodeType() const noexcept { return node_type_; }
/** Gets the opset version that the Node's operator was first defined in.
@returns Opset version. If -1 the Node's operator has not been set.
@remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build
@ -1029,13 +1029,12 @@ class Graph {
/** Returns true if the name is for a value that is coming from outer scope */
bool IsOuterScopeValue(const std::string& name) const {
#if !defined(ORT_MINIMAL_BUILD)
return resolve_context_.outer_scope_node_args.find(name) != resolve_context_.outer_scope_node_args.cend();
#else
// we shouldn't have code that calls this in a minimal build
ORT_UNUSED_PARAMETER(name);
ORT_THROW("Internal error. Outer scope value lookup is not currently supported in a minimal build.");
#endif
if (!parent_node_) return false;
const auto& implicit_input_defs = parent_node_->ImplicitInputDefs();
return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(),
[&name](const NodeArg* implicit_input) {
return implicit_input->Name() == name;
});
}
#if !defined(ORT_MINIMAL_BUILD)

View file

@ -234,36 +234,39 @@ void SessionState::CleanInitializedTensorsFromGraph() {
graph_.CleanAllInitializedTensors();
}
Status SessionState::PrepackInitializedConstantTensors() {
// calculate the use count of each value
std::unordered_map<std::string, size_t> node_arg_use_count;
for (const auto& node : GetGraphViewer().Nodes()) {
node.ForEachDef([&](const onnxruntime::NodeArg& node_arg, bool is_input) {
if (is_input) {
node_arg_use_count[node_arg.Name()]++;
}
});
}
Status SessionState::PrepackConstantInitializedTensors(std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
for (auto& node : GetGraphViewer().Nodes()) {
auto kernel = GetMutableKernel(node.Index());
int input_idx = 0;
for (auto& input_def : node.InputDefs()) {
if (input_def->Exists()) {
const std::string& input_name = input_def->Name();
int ort_value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(input_name, ort_value_idx));
if (constant_initialized_tensors_.count(ort_value_idx) &&
constant_initialized_tensors_[ort_value_idx].IsTensor()) {
bool is_packed = false;
const Tensor& const_initialized_tensor = constant_initialized_tensors_[ort_value_idx].Get<Tensor>();
ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, is_packed));
if (is_packed && node_arg_use_count.count(input_name) && --node_arg_use_count[input_name] == 0) {
// release the constant intialized tensor
initialized_tensors_.erase(ort_value_idx);
constant_initialized_tensors_.erase(ort_value_idx);
SessionState* st = this;
// subgraph can use the value from outer scope,
// so it needs to check if current node uses constant initialized tensor from current and outer graphs
do {
int ort_value_idx;
if (st->GetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) {
std::unordered_map<int, OrtValue>& constant_initialized_tensors = st->constant_initialized_tensors_;
if (constant_initialized_tensors.count(ort_value_idx)) {
bool is_packed = false;
const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get<Tensor>();
ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, is_packed));
if (is_packed && constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) {
// release the constant initialized tensor
st->initialized_tensors_.erase(ort_value_idx);
constant_initialized_tensors.erase(ort_value_idx);
}
}
// stop searching in 2 cases:
// 1. value is not from OuterScope
// 2. value is from OuterScope and the current OuterScope has the value
if (st != this || !st->graph_.IsOuterScopeValue(input_name)) {
break;
}
}
}
st = st->Parent();
} while (st);
}
input_idx++;
}
@ -567,10 +570,13 @@ void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const s
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index,
" for attribute ", attribute_name);
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
session_state->parent_ = this;
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
GenerateGraphId();
#endif
subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state)));
}
@ -776,6 +782,27 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
}
#endif
// Calculate the use count of a constant initialized tensor, including the use in subgraph.
// Note: This function doesn't handle the case below:
// The main graph has a constant initializer called X, and the subgraph also has a constant initializer called X, which overrides the X from main graph.
// For case like this, the current implementation will calculate the use count as 2, but they could contain completely different values so each should have a use count of 1.
// This is a very rare case. If it happens and X is prepacked, the consequence is that X won't be released and memory usage of X won't be saved. This will be fine.
static void ComputeConstantInitializerUseCount(const Graph& graph, std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
for (const auto& node : graph.Nodes()) {
for (const auto* arg : node.InputDefs()) {
if (arg->Exists() && graph.GetConstantInitializer(arg->Name(), true /*check_outer_scope*/)) {
constant_initializers_use_count[arg->Name()]++;
}
}
if (node.ContainsSubgraph()) {
for (const gsl::not_null<const Graph*>& subgraph : node.GetSubgraphs()) {
ComputeConstantInitializerUseCount(*subgraph, constant_initializers_use_count);
}
}
}
}
Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options,
@ -807,15 +834,18 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
#endif
}
std::unordered_map<std::string, size_t> constant_initializers_use_count;
ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count);
return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, session_options,
remove_initializers);
remove_initializers, constant_initializers_use_count);
}
Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers) {
bool remove_initializers,
std::unordered_map<std::string, size_t>& constant_initializers_use_count) {
CreateGraphInfo();
// ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
@ -868,7 +898,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
session_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisablePrepacking, "0");
if (disable_prepacking != "1") {
ORT_RETURN_IF_ERROR(PrepackInitializedConstantTensors());
ORT_RETURN_IF_ERROR(PrepackConstantInitializedTensors(constant_initializers_use_count));
}
ORT_RETURN_IF_ERROR(
@ -896,7 +926,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
// recurse
ORT_RETURN_IF_ERROR(subgraph_session_state.FinalizeSessionStateImpl(
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers));
graph_location, kernel_registry_manager, &node, subgraph_session_options, remove_initializers, constant_initializers_use_count));
// setup all the info for handling the feeds and fetches used in subgraph execution
auto* p_op_kernel = GetMutableKernel(node.Index());

View file

@ -279,6 +279,10 @@ class SessionState {
const onnxruntime::experimental::fbs::SessionState* serialized_session_state = nullptr,
bool remove_initializers = true);
SessionState* Parent() {
return parent_;
}
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState);
@ -298,7 +302,7 @@ class SessionState {
* Prepack the constant initialized tensors for better performance.
* The original constant initialized tensors will be removed to save memory.
*/
Status PrepackInitializedConstantTensors();
Status PrepackConstantInitializedTensors(std::unordered_map<std::string, size_t>& constant_initializers_use_count);
SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name);
@ -315,7 +319,8 @@ class SessionState {
KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
const SessionOptions& session_options,
bool remove_initializers);
bool remove_initializers,
std::unordered_map<std::string, size_t>& constant_initializers_use_count);
#ifdef ENABLE_TRAINING
Status GeneratePatternGroupCache(
@ -421,9 +426,9 @@ class SessionState {
std::map<std::vector<int>, std::unordered_set<NodeIndex>> to_be_executed_nodes_;
#endif
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
SessionState* parent_ = nullptr;
//Assign each graph in each session an unique id.
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
int graph_id_ = 0;
int next_graph_id_ = 1;

View file

@ -189,20 +189,7 @@ class PrePackingTestOpKernel : public OpKernel {
}
};
class SessionStatePrepackingTest : public testing::TestWithParam<bool> {};
TEST_P(SessionStatePrepackingTest, PrePackingTest) {
OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
ONNX_OPERATOR_SCHEMA(PrePackingTest)
.SetDoc("Faking Node for PrePacking")
.Input(0, "Input_0", "input 0", "tensor(float)")
.Input(1, "Input_1", "input 1", "tensor(float)")
.Output(0, "output_0", "docstr for output_0.", "tensor(float)");
onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger());
// construct graph
auto& graph = model.MainGraph();
static void CreateSimpleGraph(Graph& graph) {
// node creation and placement
TypeProto type;
type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
@ -218,8 +205,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
onnxruntime::NodeArg output_arg("node_0_output_0", &type);
outputs.push_back(&output_arg);
onnxruntime::Node& node = graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs);
node.SetExecutionProviderType(kCpuExecutionProvider);
graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs);
// add an initializer
ONNX_NAMESPACE::TensorProto tensor;
@ -231,6 +217,123 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK());
}
static const ONNX_NAMESPACE::GraphProto CreateSubgraph(bool then_branch) {
Model model(then_branch ? "If_then" : "If_else", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
const std::string suffix = then_branch ? "0" : "1";
// graph input has to have type and rank even though it's an outer scope value.
TypeProto type_float;
type_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
type_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
// outer scope values
auto& if_shared = graph.GetOrCreateNodeArg("if_shared", &type_float);
auto& if_input = graph.GetOrCreateNodeArg("if_input_" + suffix, &type_float);
// add so that we don't end up with it being considered a graph input
graph.AddOuterScopeNodeArg("if_shared");
graph.AddOuterScopeNodeArg("if_input_" + suffix);
auto& if_out = graph.GetOrCreateNodeArg("if_output_" + suffix, &type_float);
inputs = {&if_shared, &if_input};
outputs = {&if_out};
graph.AddNode("if_node_" + suffix, "PrePackingTest", "if node " + suffix, inputs, outputs);
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());
auto& proto = graph.ToGraphProto();
return proto;
}
static void CreateGraphWithSubgraph(Graph& graph) {
TypeProto type_float;
type_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
type_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
{
std::vector<onnxruntime::NodeArg*> inputs;
onnxruntime::NodeArg input_0_arg("if_input_0", &type_float);
onnxruntime::NodeArg input_1_arg("if_input_1", &type_float);
inputs.push_back(&input_0_arg);
inputs.push_back(&input_1_arg);
std::vector<onnxruntime::NodeArg*> outputs;
onnxruntime::NodeArg output_arg("node_0_output_0", &type_float);
outputs.push_back(&output_arg);
graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs);
}
{
TypeProto type_bool;
type_bool.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
type_bool.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
onnxruntime::NodeArg bool_arg("bool_arg", &type_bool);
std::vector<onnxruntime::NodeArg*> outputs;
onnxruntime::NodeArg output_arg("output_arg", &type_float);
outputs.push_back(&output_arg);
auto& if_node = graph.AddNode("if", "If", "If node", {&bool_arg}, outputs);
auto then_proto = CreateSubgraph(true);
auto else_proto = CreateSubgraph(false);
if_node.AddAttribute("then_branch", then_proto);
if_node.AddAttribute("else_branch", else_proto);
}
// add an initializer
ONNX_NAMESPACE::TensorProto tensor;
tensor.add_dims(1);
tensor.add_float_data(1.0f);
tensor.set_data_type(TensorProto_DataType_FLOAT);
tensor.set_name("if_shared");
graph.AddInitializedTensor(tensor);
auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK());
}
static void PlaceAllNodesToCPUEP(Graph& graph) {
for (auto& node : graph.Nodes()) {
node.SetExecutionProviderType(kCpuExecutionProvider);
if (node.ContainsSubgraph()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
Graph* subgraph = entry.second;
PlaceAllNodesToCPUEP(*subgraph);
}
}
}
}
struct PrepackingTestParam {
bool test_subgraph;
bool test_prepacking;
};
class SessionStatePrepackingTest : public testing::TestWithParam<PrepackingTestParam> {};
TEST_P(SessionStatePrepackingTest, PrePackingTest) {
PrepackingTestParam test_param = GetParam();
OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
ONNX_OPERATOR_SCHEMA(PrePackingTest)
.SetDoc("Faking Node for PrePacking")
.Input(0, "Input_0", "input 0", "tensor(float)")
.Input(1, "Input_1", "input 1", "tensor(float)")
.Output(0, "output_0", "docstr for output_0.", "tensor(float)");
ExecutionProviders execution_providers;
auto cpu_execution_provider = onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
@ -238,7 +341,21 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
DataTransferManager dtm;
profiling::Profiler profiler;
SessionState session_state(graph,
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = 11;
Model model("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
DefaultLoggingManager().DefaultLogger());
// onnxruntime::Model model("graph_main", false, DefaultLoggingManager().DefaultLogger());
if (test_param.test_subgraph) {
CreateGraphWithSubgraph(model.MainGraph());
} else {
CreateSimpleGraph(model.MainGraph());
}
SessionState session_state(model.MainGraph(),
execution_providers,
true, /*enable_mem_pattern*/
tp.get(),
@ -248,7 +365,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
profiler);
KernelRegistryManager kernel_registry_manager;
status = kernel_registry_manager.RegisterKernels(execution_providers);
Status status = kernel_registry_manager.RegisterKernels(execution_providers);
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>();
auto kernel_def = KernelDefBuilder().SetName("PrePackingTest").Provider(kCpuExecutionProvider).SinceVersion(1).Build();
@ -257,19 +374,25 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) {
[](const OpKernelInfo& info) -> OpKernel* { return new PrePackingTestOpKernel(info); })));
kernel_registry_manager.RegisterKernelRegistry(kernel_registry);
PlaceAllNodesToCPUEP(model.MainGraph());
SessionOptions sess_options;
bool use_prepacking = GetParam();
sess_options.session_configurations[kOrtSessionOptionsConfigDisablePrepacking] = use_prepacking ? "0" : "1";
sess_options.session_configurations[kOrtSessionOptionsConfigDisablePrepacking] = test_param.test_prepacking ? "0" : "1";
ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string<PATH_CHAR_TYPE>(),
kernel_registry_manager,
sess_options));
const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors();
// check prepacking
ASSERT_EQ(const_initialized_tensors.size(), size_t(use_prepacking ? 0 : 1));
ASSERT_EQ(const_initialized_tensors.size(), size_t(test_param.test_prepacking ? 0 : 1));
}
INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, testing::Values(true, false));
INSTANTIATE_TEST_SUITE_P(SessionStateTests,
SessionStatePrepackingTest,
testing::Values(PrepackingTestParam{false, false},
PrepackingTestParam{false, true},
PrepackingTestParam{true, false},
PrepackingTestParam{true, true}));
} // namespace test
} // namespace onnxruntime