Fix bug in CPU force fallback logic (#8597)

This commit is contained in:
Hariharan Seshadri 2021-08-05 21:36:28 -07:00 committed by GitHub
parent f3a1aebb33
commit e791faeca5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 3 deletions

View file

@ -654,6 +654,14 @@ class Graph {
*/
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const;
/** returns the initializer's TensorProto if 'name' is an initializer (both constant and overridable).
If the initializer is not found, a nullptr is returned.
@param check_outer_scope If true and the graph is a subgraph,
check ancestor graph/s for 'name' if not found in 'graph'.
@remarks check_outer_scope of true is not supported in a minimal build
*/
const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope) const;
/** Gets the Graph inputs excluding initializers.
These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs.
@remarks Contains no nullptr values. */

View file

@ -16,10 +16,18 @@ namespace onnxruntime {
namespace {
const int64_t kSmallInitializerThreshold = 100;
bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const NodeArg* arg) {
const ONNX_NAMESPACE::TensorProto* initializer_tensor;
if (!graph.GetInitializedTensor(arg->Name(), initializer_tensor))
static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const NodeArg* arg) {
// 'true' in the function call is to let the searching for the initializer
// continue in the outer scopes of the current (sub-)graph if applicable
const ONNX_NAMESPACE::TensorProto* initializer_tensor =
graph.GetGraph().GetInitializer(arg->Name(), true);
// Not an initializer at all
if (initializer_tensor == nullptr) {
return false;
}
// Check if "small" enough
int64_t size = 1;
for (auto& dim : initializer_tensor->dims()) {
size *= dim;

View file

@ -2787,6 +2787,21 @@ const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::stri
return initializer;
}
const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& initializer_name,
bool check_outer_scope) const {
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
if (GetInitializedTensor(initializer_name, initializer)) {
return initializer;
} else if (check_outer_scope && IsSubgraph()) {
// make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope.
if (IsOuterScopeValue(initializer_name)) {
initializer = parent_graph_->GetInitializer(initializer_name, check_outer_scope);
}
}
return initializer;
}
#if !defined(ORT_MINIMAL_BUILD)
void Graph::AddValueInfo(const NodeArg* new_value_info) {
NodeArg* node_arg = GetNodeArg(new_value_info->Name());