diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 361047e8d8..be49b2add7 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -654,6 +654,14 @@ class Graph { */ const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; + /** returns the initializer's TensorProto if 'name' is an initializer (both constant and overridable). + If the initializer is not found, a nullptr is returned. + @param check_outer_scope If true and the graph is a subgraph, + check ancestor graph/s for 'name' if not found in 'graph'. + @remarks check_outer_scope of true is not supported in a minimal build + */ + const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope) const; + /** Gets the Graph inputs excluding initializers. These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs. @remarks Contains no nullptr values. */ diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 011eaf9edb..28e309f26d 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -16,10 +16,18 @@ namespace onnxruntime { namespace { const int64_t kSmallInitializerThreshold = 100; -bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const NodeArg* arg) { - const ONNX_NAMESPACE::TensorProto* initializer_tensor; - if (!graph.GetInitializedTensor(arg->Name(), initializer_tensor)) +static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const NodeArg* arg) { + // 'true' in the function call is to let the searching for the initializer + // continue in the outer scopes of the current (sub-)graph if applicable + const ONNX_NAMESPACE::TensorProto* initializer_tensor = + graph.GetGraph().GetInitializer(arg->Name(), true); + + // Not an initializer at all + if (initializer_tensor == nullptr) { return false; + } + + // Check if "small" enough int64_t size = 1; for (auto& dim : initializer_tensor->dims()) { size *= dim; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 59bca28e45..77174cde50 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2787,6 +2787,21 @@ const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::stri return initializer; } +const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& initializer_name, + bool check_outer_scope) const { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (GetInitializedTensor(initializer_name, initializer)) { + return initializer; + } else if (check_outer_scope && IsSubgraph()) { + // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. + if (IsOuterScopeValue(initializer_name)) { + initializer = parent_graph_->GetInitializer(initializer_name, check_outer_scope); + } + } + + return initializer; +} + #if !defined(ORT_MINIMAL_BUILD) void Graph::AddValueInfo(const NodeArg* new_value_info) { NodeArg* node_arg = GetNodeArg(new_value_info->Name());