mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix bug in CPU force fallback logic (#8597)
This commit is contained in:
parent
f3a1aebb33
commit
e791faeca5
3 changed files with 34 additions and 3 deletions
|
|
@ -654,6 +654,14 @@ class Graph {
|
|||
*/
|
||||
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const;
|
||||
|
||||
/** returns the initializer's TensorProto if 'name' is an initializer (both constant and overridable).
|
||||
If the initializer is not found, a nullptr is returned.
|
||||
@param check_outer_scope If true and the graph is a subgraph,
|
||||
check ancestor graph/s for 'name' if not found in 'graph'.
|
||||
@remarks check_outer_scope of true is not supported in a minimal build
|
||||
*/
|
||||
const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope) const;
|
||||
|
||||
/** Gets the Graph inputs excluding initializers.
|
||||
These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs.
|
||||
@remarks Contains no nullptr values. */
|
||||
|
|
|
|||
|
|
@ -16,10 +16,18 @@ namespace onnxruntime {
|
|||
namespace {
|
||||
const int64_t kSmallInitializerThreshold = 100;
|
||||
|
||||
bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const NodeArg* arg) {
|
||||
const ONNX_NAMESPACE::TensorProto* initializer_tensor;
|
||||
if (!graph.GetInitializedTensor(arg->Name(), initializer_tensor))
|
||||
static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const NodeArg* arg) {
|
||||
// 'true' in the function call is to let the searching for the initializer
|
||||
// continue in the outer scopes of the current (sub-)graph if applicable
|
||||
const ONNX_NAMESPACE::TensorProto* initializer_tensor =
|
||||
graph.GetGraph().GetInitializer(arg->Name(), true);
|
||||
|
||||
// Not an initializer at all
|
||||
if (initializer_tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if "small" enough
|
||||
int64_t size = 1;
|
||||
for (auto& dim : initializer_tensor->dims()) {
|
||||
size *= dim;
|
||||
|
|
|
|||
|
|
@ -2787,6 +2787,21 @@ const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::stri
|
|||
return initializer;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& initializer_name,
|
||||
bool check_outer_scope) const {
|
||||
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
|
||||
if (GetInitializedTensor(initializer_name, initializer)) {
|
||||
return initializer;
|
||||
} else if (check_outer_scope && IsSubgraph()) {
|
||||
// make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope.
|
||||
if (IsOuterScopeValue(initializer_name)) {
|
||||
initializer = parent_graph_->GetInitializer(initializer_name, check_outer_scope);
|
||||
}
|
||||
}
|
||||
|
||||
return initializer;
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
void Graph::AddValueInfo(const NodeArg* new_value_info) {
|
||||
NodeArg* node_arg = GetNodeArg(new_value_info->Name());
|
||||
|
|
|
|||
Loading…
Reference in a new issue