Eliminate hashmap copies during function inlining (#17439)

### Description
Eliminate unnecessary HashMap copies. This saves 22% of CPU usage on a
reference Dynamo exported model.

### Motivation and Context
Our function inlining is currently slow.

Before:


![image](https://github.com/microsoft/onnxruntime/assets/11303988/fd38a857-8c12-42ef-9de2-3485123a9fe7)

After


![image](https://github.com/microsoft/onnxruntime/assets/11303988/ea65813d-26cb-41dc-ba55-6a609b169767)
This commit is contained in:
Dmitri Smirnov 2023-09-07 14:08:38 -07:00 committed by GitHub
parent 024f1dd72b
commit 21c202bb5d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 19 deletions

View file

@ -269,7 +269,7 @@ static void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_fun
std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& function_domain,
const std::string& function_name,
const InlinedHashMap<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions,
const std::unordered_map<std::string, int>& domain_version_map,
const SchemaRegistryManager& schema_registry,
const logging::Logger& logger,
@ -315,6 +315,7 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& functi
schema_registry.GetLastReleasedOpsetVersions(false);
std::unordered_map<std::string, int> func_domain_to_version;
func_domain_to_version.reserve(onnx_func_proto->opset_import().size());
for (auto& opSet : onnx_func_proto->opset_import()) {
const auto& domain = opSet.domain();
const auto version = gsl::narrow_cast<int>(opSet.version());
@ -332,18 +333,16 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& functi
}
}
// Instantiate once and reuse for all shape inference calls.
constexpr bool check_type_true = true;
constexpr int error_mode_throw = 1;
constexpr bool enable_data_propagation_false = false;
static const ONNX_NAMESPACE::ShapeInferenceOptions inference_options{check_type_true, error_mode_throw, enable_data_propagation_false};
// model_local_functions is a member of Model instance and will be alive at the time this is invoked.
op_schema->TypeAndShapeInferenceFunction(
[onnx_func_proto, func_domain_to_version, &model_local_functions](ONNX_NAMESPACE::InferenceContext& ctx) {
auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
constexpr bool check_type_true = true;
constexpr int error_mode_throw = 1;
constexpr bool enable_data_propagation_false = false;
ONNX_NAMESPACE::ShapeInferenceOptions options{check_type_true, error_mode_throw, enable_data_propagation_false};
std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> map_copy(model_local_functions.begin(),
model_local_functions.end());
std::unordered_map<std::string, TensorShapeProto> empty_map;
[onnx_func_proto, func_domain_to_version = std::move(func_domain_to_version), &model_local_functions](ONNX_NAMESPACE::InferenceContext& ctx) {
auto* schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
// https://github.com/microsoft/onnxruntime/issues/17061
// We are passing a nullptr for the symbol table, because symbol table must be global
@ -351,8 +350,8 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& functi
// the same symbolic shapes and are marked for memory re-use. This is a Temp fix.
constexpr ONNX_NAMESPACE::shape_inference::SymbolTableImpl* symbolTable = nullptr;
ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*onnx_func_proto, func_domain_to_version,
schema_registry, ctx, options, map_copy,
symbolTable, &empty_map);
schema_registry, ctx, inference_options, model_local_functions,
symbolTable, nullptr);
});
op_schema->Finalize();

View file

@ -31,6 +31,8 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const Graph& graph,
* @param function_name The name of the function.
* @param model_local_functions The map of local functions in the same onnx model.
* This will be used as context for the function's type/shape inference.
* This argument is captured by shape inferencing lambda by reference and must
* be alive at the time of the shape inferencing.
* @param domain_version_map Domain to version map used in current onnx model.
* @param schema_registry The schema registry current model is using.
* @param logger The logger current model is using.
@ -38,7 +40,7 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const Graph& graph,
*/
std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& function_domain,
const std::string& function_name,
const InlinedHashMap<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_local_functions,
const std::unordered_map<std::string, int>& domain_version_map,
const SchemaRegistryManager& schema_registry,
const logging::Logger& logger,

View file

@ -91,10 +91,11 @@ Model::Model(const std::string& graph_name,
opset_id_proto->set_version(version);
}
model_local_functions_.reserve(model_local_functions.size());
for (auto& func : model_local_functions) {
auto func_ptr = model_proto_.add_functions();
func_ptr->CopyFrom(func);
model_local_functions_[function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name())] = func_ptr;
model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), func_ptr);
}
model_local_function_templates_.reserve(model_proto_.functions().size());
@ -214,9 +215,9 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path,
}
}
std::vector<const ONNX_NAMESPACE::FunctionProto*> model_local_functions;
model_local_functions_.reserve(model_proto_.functions().size());
for (auto& func : model_proto_.functions()) {
model_local_functions_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = &func;
model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), &func);
}
model_local_function_templates_.reserve(model_proto_.functions().size());

View file

@ -310,7 +310,8 @@ class Model {
// map from function id to pointer of model local function proto
// FunctionProto is hosted in ModelProto.
// this map will be used for the local functions' schema's type/shape inference.
InlinedHashMap<std::string, const ONNX_NAMESPACE::FunctionProto*> model_local_functions_;
// This container is used by ONNX code and must be an std::unordered_map.
std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> model_local_functions_;
// this is the container that host the generated schemas for model local functions.
// the generated schemare will be used for graph resolving and type/shape inference.
// those schemas' type/shape inference will reference to the model_local_functions_ as context,