diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index aa0727e375..7477f48088 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -269,7 +269,7 @@ static void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_fun std::unique_ptr CreateSchema(const std::string& function_domain, const std::string& function_name, - const InlinedHashMap& model_local_functions, + const std::unordered_map& model_local_functions, const std::unordered_map& domain_version_map, const SchemaRegistryManager& schema_registry, const logging::Logger& logger, @@ -315,6 +315,7 @@ std::unique_ptr CreateSchema(const std::string& functi schema_registry.GetLastReleasedOpsetVersions(false); std::unordered_map func_domain_to_version; + func_domain_to_version.reserve(onnx_func_proto->opset_import().size()); for (auto& opSet : onnx_func_proto->opset_import()) { const auto& domain = opSet.domain(); const auto version = gsl::narrow_cast(opSet.version()); @@ -332,18 +333,16 @@ std::unique_ptr CreateSchema(const std::string& functi } } + // Instantiate once and reuse for all shape inference calls. + constexpr bool check_type_true = true; + constexpr int error_mode_throw = 1; + constexpr bool enable_data_propagation_false = false; + static const ONNX_NAMESPACE::ShapeInferenceOptions inference_options{check_type_true, error_mode_throw, enable_data_propagation_false}; + + // model_local_functions is a member of Model instance and will be alive at the time this is invoked. op_schema->TypeAndShapeInferenceFunction( - [onnx_func_proto, func_domain_to_version, &model_local_functions](ONNX_NAMESPACE::InferenceContext& ctx) { - auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); - - constexpr bool check_type_true = true; - constexpr int error_mode_throw = 1; - constexpr bool enable_data_propagation_false = false; - ONNX_NAMESPACE::ShapeInferenceOptions options{check_type_true, error_mode_throw, enable_data_propagation_false}; - - std::unordered_map map_copy(model_local_functions.begin(), - model_local_functions.end()); - std::unordered_map empty_map; + [onnx_func_proto, func_domain_to_version = std::move(func_domain_to_version), &model_local_functions](ONNX_NAMESPACE::InferenceContext& ctx) { + auto* schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); // https://github.com/microsoft/onnxruntime/issues/17061 // We are passing a nullptr for the symbol table, because symbol table must be global @@ -351,8 +350,8 @@ std::unique_ptr CreateSchema(const std::string& functi // the same symbolic shapes and are marked for memory re-use. This is a Temp fix. constexpr ONNX_NAMESPACE::shape_inference::SymbolTableImpl* symbolTable = nullptr; ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*onnx_func_proto, func_domain_to_version, - schema_registry, ctx, options, map_copy, - symbolTable, &empty_map); + schema_registry, ctx, inference_options, model_local_functions, + symbolTable, nullptr); }); op_schema->Finalize(); diff --git a/onnxruntime/core/graph/function_utils.h b/onnxruntime/core/graph/function_utils.h index d2bb86d107..34e5e57189 100644 --- a/onnxruntime/core/graph/function_utils.h +++ b/onnxruntime/core/graph/function_utils.h @@ -31,6 +31,8 @@ std::unique_ptr CreateSchema(const Graph& graph, * @param function_name The name of the function. * @param model_local_functions The map of local functions in the same onnx model. * This will be used as context for the function's type/shape inference. + * This argument is captured by shape inferencing lambda by reference and must + * be alive at the time of the shape inferencing. * @param domain_version_map Domain to version map used in current onnx model. * @param schema_registry The schema registry current model is using. * @param logger The logger current model is using. @@ -38,7 +40,7 @@ std::unique_ptr CreateSchema(const Graph& graph, */ std::unique_ptr CreateSchema(const std::string& function_domain, const std::string& function_name, - const InlinedHashMap& model_local_functions, + const std::unordered_map& model_local_functions, const std::unordered_map& domain_version_map, const SchemaRegistryManager& schema_registry, const logging::Logger& logger, diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index d206af1acf..05747a7e51 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -91,10 +91,11 @@ Model::Model(const std::string& graph_name, opset_id_proto->set_version(version); } + model_local_functions_.reserve(model_local_functions.size()); for (auto& func : model_local_functions) { auto func_ptr = model_proto_.add_functions(); func_ptr->CopyFrom(func); - model_local_functions_[function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name())] = func_ptr; + model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), func_ptr); } model_local_function_templates_.reserve(model_proto_.functions().size()); @@ -214,9 +215,9 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, } } - std::vector model_local_functions; + model_local_functions_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { - model_local_functions_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = &func; + model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), &func); } model_local_function_templates_.reserve(model_proto_.functions().size()); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 7e3942b029..6bdb68dd73 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -310,7 +310,8 @@ class Model { // map from function id to pointer of model local function proto // FunctionProto is hosted in ModelProto. // this map will be used for the local functions' schema's type/shape inference. - InlinedHashMap model_local_functions_; + // This container is used by ONNX code and must be an std::unordered_map. + std::unordered_map model_local_functions_; // this is the container that host the generated schemas for model local functions. // the generated schemare will be used for graph resolving and type/shape inference. // those schemas' type/shape inference will reference to the model_local_functions_ as context,