Fix location planning for initializers used only in nested subgraphs (#8642)

This commit is contained in:
Hariharan Seshadri 2021-09-01 00:02:08 -07:00 committed by GitHub
parent 4dc0ddf606
commit acd9db7fad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 324 additions and 42 deletions

View file

@ -381,6 +381,12 @@ class Node {
return attr_to_subgraph_map_;
}
/** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node.
@returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
The returned map is empty if the Node has no subgraphs.
*/
std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;
/** Gets the execution ProviderType that this node will be executed by. */
ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }

View file

@ -20,6 +20,27 @@ using namespace onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace NestedSubgraphInfoDetails {
// Builds the string key that identifies one nested subgraph relative to the graph
// level currently being processed (that level is itself identified by `base`).
// The parts are streamed back-to-back, with no separator, in this order:
//   base, graph_depth, node_index, attr_name
// where `node_index` is the index of the node holding the subgraph and `attr_name`
// is the name of the attribute that carries the subgraph.
std::string ComposeNestedSubgraphInfoKeyHelper(const std::string& base,
                                               size_t graph_depth,
                                               NodeIndex node_index,
                                               const std::string& attr_name) {
  std::ostringstream key_stream;
  key_stream << base << graph_depth << node_index << attr_name;
  return key_stream.str();
}
}  // namespace NestedSubgraphInfoDetails
std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind) {
switch (alloc_kind) {
case AllocKind::kAllocate:
@ -107,7 +128,7 @@ std::ostream& operator<<(std::ostream& out, std::pair<const SequentialExecutionP
}
static const KernelCreateInfo& GetKernelCreateInfo(
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
const KernelCreateInfoMap& kernel_create_info_map,
NodeIndex node_index) {
auto entry = kernel_create_info_map.find(node_index);
ORT_ENFORCE(entry != kernel_create_info_map.cend(),
@ -120,7 +141,8 @@ class PlannerImpl {
public:
PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const NodeArg*>& outer_scope_node_args, const ExecutionProviders& providers,
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
const KernelCreateInfoMap& kernel_create_info_map,
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
@ -131,6 +153,7 @@ class PlannerImpl {
outer_scope_node_args_(outer_scope_node_args),
execution_providers_(providers),
kernel_create_info_map_(kernel_create_info_map),
subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps),
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
ort_value_name_idx_map_(ort_value_name_idx_map) {}
@ -145,7 +168,8 @@ class PlannerImpl {
const std::vector<const NodeArg*>& outer_scope_node_args_;
const ExecutionProviders& execution_providers_;
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map_;
const KernelCreateInfoMap& kernel_create_info_map_;
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps_;
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map_;
@ -609,11 +633,12 @@ class PlannerImpl {
return Status::OK();
}
OrtMemoryInfo GetLocationForNodeInput(size_t input_index, const Node& node) {
OrtMemoryInfo GetLocationForNodeInput(size_t input_index, const Node& node,
const KernelCreateInfoMap& kernel_create_info_map) {
auto* p_provider = execution_providers_.Get(node);
ORT_ENFORCE(p_provider);
const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, node.Index());
const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map, node.Index());
if (utils::IsInputOnCpu(node, &kernel_create_info, input_index))
// weights are not output from any node, so it's OK to put its location on CPU provider
@ -621,29 +646,95 @@ class PlannerImpl {
return p_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
}
Status GeneratePlanForWeights() {
auto& weights = graph_viewer_.GetAllInitializedTensors();
std::vector<std::vector<OrtMemoryInfo>> locations(plan_.allocation_plan.size());
for (const auto& node : graph_viewer_.Nodes()) {
auto status = onnxruntime::Node::ForEachWithIndex(
node.InputDefs(), [this, &locations, &node, &weights](const onnxruntime::NodeArg& def, size_t index) {
auto sub_status = Status::OK();
ORT_TRY {
auto& def_name = def.Name();
if (!weights.count(def_name)) return Status::OK();
auto wt_index = Index(def_name);
locations[wt_index].emplace_back(GetLocationForNodeInput(index, node));
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
sub_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
});
}
return sub_status;
});
// Walks every node of `graph_viewer` and, for each node input whose name is present in
// `weights`, appends the location returned by GetLocationForNodeInput() to
// `locations[Index(name)]`. Nodes that contain subgraphs are processed recursively so
// that uses of a weight inside nested subgraphs also contribute a location.
//
// graph_viewer: the graph level being walked.
// weights: the set of initializer names to look for; the same set is passed unchanged
//          to every recursive call.
// kernel_create_info_map: kernel create info for the nodes of `graph_viewer`.
// subgraph_kernel_create_info_map_key_base: key identifying the current graph level; used as
//          the prefix when composing the keys of the subgraphs found at this level.
// graph_depth: 0 for the top-level call, incremented by one for each level of recursion.
// locations: output; indexed by the value returned from Index() for the weight name.
void GeneratePlanForWeightsHelper(const GraphViewer& graph_viewer,
const InitializedTensorSet& weights,
const KernelCreateInfoMap& kernel_create_info_map,
const std::string& subgraph_kernel_create_info_map_key_base,
size_t graph_depth,
/*out*/ std::vector<std::vector<OrtMemoryInfo>>& locations) {
for (const auto& node : graph_viewer.Nodes()) {
const auto& input_node_args = node.InputDefs();
size_t num_node_inputs = input_node_args.size();
// NOTE(review): no `status` variable is declared anywhere in this function, and the
// function returns void - the next statement looks like a leftover line from the
// previous implementation. Confirm against the actual file.
ORT_RETURN_IF_ERROR(status);
for (size_t node_input_index = 0; node_input_index < num_node_inputs; ++node_input_index) {
auto input_node_arg = input_node_args[node_input_index];
// Skip processing missing optional inputs
if (!input_node_arg->Exists()) {
continue;
}
auto& def_name = input_node_arg->Name();
// This node input doesn't correspond to any of the weights
if (!weights.count(def_name)) {
continue;
}
// While processing subgraphs, if we don't see an entry in the implicit
// inputs of the node containing the subgraph, it is a shadow value.
// (The lambda returns true when `def_name` is NOT among the implicit inputs
// of `subgraph_parent_node`.)
auto is_shadow_value_in_subgraph = [](const Node& subgraph_parent_node,
const std::string& def_name) -> bool {
bool is_shadow_value_in_subgraph = true;
for (const auto& implicit_input : subgraph_parent_node.ImplicitInputDefs()) {
if (implicit_input->Name() == def_name) {
is_shadow_value_in_subgraph = false;
break;
}
}
return is_shadow_value_in_subgraph;
};
// Skip processing shadow values in subgraphs
if (graph_depth > 0) {
// We are processing a subgraph if we enter this
const auto* parent_node = graph_viewer.ParentNode();
// Skip processing if it is a shadow value
if (is_shadow_value_in_subgraph(*parent_node, def_name)) {
continue;
}
}
// Record one more candidate location for this weight; the caller decides the
// final location from all the entries collected in `locations`.
auto wt_index = Index(def_name);
locations[wt_index].emplace_back(
GetLocationForNodeInput(node_input_index, node, kernel_create_info_map));
}
// If the node has subgraphs (i.e.) control flow nodes,
// walk the nodes in those subgraphs as well to best determine
// the location for the OrtValue corresponding to the weights
// (i.e.) do a recursion
if (node.ContainsSubgraph()) {
// A node may contain multiple subgraphs - so iterate through all of them
for (auto& name_to_subgraph : node.GetAttributeNameToSubgraphMap()) {
GraphViewer subgraph_viewer(*name_to_subgraph.second);
// Key for this particular subgraph: current level's key + depth + node index + attribute name
const auto& local_subgraph_kernel_create_info_map_key =
NestedSubgraphInfoDetails::ComposeNestedSubgraphInfoKeyHelper(subgraph_kernel_create_info_map_key_base,
graph_depth, node.Index(), name_to_subgraph.first);
// The kernel create info map for the subgraph must already be present under that key;
// a missing entry fails the ORT_ENFORCE below.
auto specific_subgraph_kernel_create_info_map = subgraphs_kernel_create_info_maps_.find(local_subgraph_kernel_create_info_map_key);
ORT_ENFORCE(specific_subgraph_kernel_create_info_map != subgraphs_kernel_create_info_maps_.end());
// Recurse one level deeper; the subgraph's key becomes the base for its own nested subgraphs
GeneratePlanForWeightsHelper(subgraph_viewer,
weights,
specific_subgraph_kernel_create_info_map->second,
local_subgraph_kernel_create_info_map_key,
graph_depth + 1,
locations);
}
}
}
}
Status GeneratePlanForWeights() {
std::vector<std::vector<OrtMemoryInfo>> locations(plan_.allocation_plan.size());
GeneratePlanForWeightsHelper(graph_viewer_, graph_viewer_.GetAllInitializedTensors(),
kernel_create_info_map_, "", 0, locations);
for (size_t i = 0; i != locations.size(); ++i) {
const std::vector<OrtMemoryInfo>& loc = locations[i];
if (loc.empty()) continue;
@ -671,7 +762,7 @@ class PlannerImpl {
// Should only be used after ProcessDef()
Status ComputeReusePlan() {
std::vector<SequentialExecutionPlan::NodeExecutionPlan>& execution_plan(plan_.execution_plan);
//copy the usecounts to an vector, before computing reuse
//copy the use counts to a vector, before computing reuse
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
std::vector<int> ort_value_usecount;
for (auto ort_value_info : ort_value_info_) {
@ -1049,7 +1140,7 @@ class PlannerImpl {
}
}
#endif
}; // namespace onnxruntime
};
Status PlannerImpl::CreatePlan() {
auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder(context_.GetExecutionOrder());
@ -1098,7 +1189,8 @@ Status SequentialPlanner::CreatePlan(
const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const NodeArg*>& outer_scope_node_args,
const ExecutionProviders& providers,
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
const KernelCreateInfoMap& kernel_create_info_map,
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context,
@ -1107,7 +1199,8 @@ Status SequentialPlanner::CreatePlan(
plan = std::make_unique<SequentialExecutionPlan>();
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
kernel_create_info_map, outer_scope_node_arg_to_location_map,
kernel_create_info_map, subgraphs_kernel_create_info_maps,
outer_scope_node_arg_to_location_map,
ort_value_name_idx_map, context, *plan);
return planner.CreatePlan();

View file

@ -15,11 +15,23 @@ class TensorShapeProto;
}
namespace onnxruntime {
namespace NestedSubgraphInfoDetails {
// Used to compose a unique key to identify a nested subgraph
// relative to a current graph level (which in turn is identified using a "base")
// base: key of the current graph level.
// graph_depth: depth of the current graph level.
// node_index: index of the node (in the current graph level) that contains the subgraph.
// attr_name: name of the node attribute that holds the subgraph.
// Returns the composed key as a string.
std::string ComposeNestedSubgraphInfoKeyHelper(const std::string& base, size_t graph_depth,
NodeIndex node_index, const std::string& attr_name);
} // namespace NestedSubgraphInfoDetails
class ExecutionProviders;
struct KernelCreateInfo;
class KernelRegistryManager;
class OrtValueNameIdxMap;
// Maps a node index to the (non-null) KernelCreateInfo associated with that node.
using KernelCreateInfoMap = std::unordered_map<onnxruntime::NodeIndex, gsl::not_null<const KernelCreateInfo*>>;
// Maps a string key identifying a nested subgraph to the KernelCreateInfoMap of that subgraph.
using SubgraphsKernelCreateInfoMaps = std::unordered_map<std::string, KernelCreateInfoMap>;
// ISequentialPlannerContext abstracts how the planner accesses information (such as inferred shape)
// to do the planning.
class ISequentialPlannerContext {
@ -65,7 +77,8 @@ class SequentialPlanner {
const Node* parent_node, const onnxruntime::GraphViewer& graph,
const std::vector<const NodeArg*>& outer_scope_node_args,
const ExecutionProviders& providers,
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
const KernelCreateInfoMap& kernel_create_info_map,
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context,

View file

@ -1174,6 +1174,57 @@ static Status OuterScopeNodeArgLocationAccumulator(const SequentialExecutionPlan
return Status::OK();
}
// We accumulate all nested subgraph(s) kernel create info maps relative to the current depth
// (i.e.) if we were on the first nested subgraph, we accumulate information from ALL the
// nested subgraphs within it.
// This information is necessary to plan the right location for initializers
// in a given level because they could be used in one of the nested subgraphs relative to the
// current level (not just within the same level or even one level deep).
// Since we need to package up information from multiple levels of nested subgraphs, the key we use
// is "{key_for_node_containing_subgraph} + current_depth + node_index_containing_the_subgraph + attribute_name".
// {key_for_node_containing_subgraph} is empty for the main graph.
// For example, if we want to store information corresponding to a nested subgraph wrt to the main graph and
// the node index of the node in the main graph was 2 and the attribute containing the specific
// subgraph was "then_branch", the key would be depth + node_index + attribute = 0 + 2 + then_branch
// = "02then_branch".
// If that subgraph contained another subgraph at node index 1, then the key would be,
// {02then_branch} + 1 + 1 + "then_branch" = "02then_branch11then_branch".
static void AccumulateAllNestedSubgraphsInfo(
const SessionState& session_state,
const std::string& subgraph_kernel_create_info_map_key_base,
size_t graph_depth,
/*out*/ SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps) {
for (const auto& entry : session_state.GetSubgraphSessionStateMap()) {
auto node_index = entry.first;
for (const auto& name_to_subgraph_session_state : entry.second) {
const auto& subgraph_attr_name = name_to_subgraph_session_state.first;
SessionState& subgraph_session_state = *name_to_subgraph_session_state.second;
const auto& local_subgraph_kernel_create_info_map_key =
NestedSubgraphInfoDetails::ComposeNestedSubgraphInfoKeyHelper(subgraph_kernel_create_info_map_key_base,
graph_depth, node_index, subgraph_attr_name);
// The end user is never likely to see an error with the following line.
// Points to an internal processing error if we hit this.
ORT_ENFORCE(subgraphs_kernel_create_info_maps.find(local_subgraph_kernel_create_info_map_key) ==
subgraphs_kernel_create_info_maps.end());
subgraphs_kernel_create_info_maps.insert({local_subgraph_kernel_create_info_map_key,
subgraph_session_state.GetKernelCreateInfoMap()});
// Recurse into the subgraph session state
AccumulateAllNestedSubgraphsInfo(subgraph_session_state,
local_subgraph_kernel_create_info_map_key,
graph_depth + 1, subgraphs_kernel_create_info_maps);
}
}
}
Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
KernelRegistryManager& kernel_registry_manager,
_In_opt_ const Node* parent_node,
@ -1201,9 +1252,13 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
});
}
SubgraphsKernelCreateInfoMaps subgraphs_kernel_create_info_maps;
AccumulateAllNestedSubgraphsInfo(*this, "", 0, subgraphs_kernel_create_info_maps);
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order, session_options.enable_mem_reuse);
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
execution_providers_, kernel_create_info_map_,
subgraphs_kernel_create_info_maps,
outer_scope_node_arg_to_location_map,
ort_value_name_idx_map_, context, p_seq_exec_plan_));
//Record the allocation plan

View file

@ -77,6 +77,12 @@ class MemoryInfo;
* Then you can use:
* s.GetKernel(...);
*/
// Map of subgraph SessionState instances. Each entry is keyed by the index of a node that contains
// subgraphs; its value maps attribute name -> SessionState, as a node may contain multiple subgraphs
// (e.g. 'If' has one each for the 'then' and 'else' branches).
using SubgraphSessionStateMap =
std::unordered_map<onnxruntime::NodeIndex, std::unordered_map<std::string, std::unique_ptr<SessionState>>>;
class SessionState {
public:
SessionState(Graph& graph,
@ -319,6 +325,14 @@ class SessionState {
return used_shared_pre_packed_weights_counter_;
}
// Returns a read-only reference to this SessionState's kernel_create_info_map_.
const KernelCreateInfoMap& GetKernelCreateInfoMap() const {
return kernel_create_info_map_;
}
// Returns a read-only reference to this SessionState's subgraph_session_states_.
const SubgraphSessionStateMap& GetSubgraphSessionStateMap() const {
return subgraph_session_states_;
}
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
void IncrementGraphExecutionCounter() {
++graph_executions_counter_;
@ -385,7 +399,7 @@ class SessionState {
}
// KernelCreateInfo for each node so we do kernel lookup once
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>> kernel_create_info_map_;
KernelCreateInfoMap kernel_create_info_map_;
// If we compile kernels in a minimal build we need a way to find the kernel using the hash.
// We populate this map when doing the kernel compilation in GraphPartitioner, and use it in LoadFromOrtFormat.
@ -469,10 +483,6 @@ class SessionState {
NameNodeInfoMapType input_names_to_nodeinfo_mapping_;
NameNodeInfoMapType output_names_to_nodeinfo_mapping_;
// subgraph SessionState. entry for node containing subgraph, with value containing attribute:SessionState pair
// as a node may contain multiple subgraphs (e.g. 'If' has one for both the 'then' and 'else' branches).
using SubgraphSessionStateMap =
std::unordered_map<onnxruntime::NodeIndex, std::unordered_map<std::string, std::unique_ptr<SessionState>>>;
SubgraphSessionStateMap subgraph_session_states_;
// either threadpool could be nullptr

View file

@ -951,6 +951,14 @@ std::vector<gsl::not_null<const Graph*>> Node::GetSubgraphs() const {
return subgraphs;
}
std::unordered_map<std::string, gsl::not_null<const Graph*>> Node::GetAttributeNameToSubgraphMap() const {
std::unordered_map<std::string, gsl::not_null<const Graph*>> attr_to_subgraphs;
for (auto& entry : attr_to_subgraph_map_) {
attr_to_subgraphs.insert({entry.first, entry.second});
}
return attr_to_subgraphs;
}
void Node::ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func,
bool include_missing_optional_defs) const {
for (const auto* arg : InputDefs()) {

View file

@ -26,7 +26,7 @@ using namespace ONNX_NAMESPACE;
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
// GCC 4.x.
// (This static var is referenced in `PassThroughExplicitAndImplicitSubgraphInputs` test)
// (This static var is referenced in some tests below)
const OrtDevice::DeviceType OrtDevice::GPU;
namespace onnxruntime {
@ -278,7 +278,7 @@ class PlannerTest : public ::testing::Test {
SequentialPlannerTestContext test_context(&shape_map_);
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_,
kernel_create_info_map, {}, state_->GetOrtValueNameIdxMap(), test_context,
kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context,
plan_);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
@ -514,7 +514,7 @@ TEST_F(PlannerTest, PlanOutputTest) {
}
#ifdef USE_CUDA
TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
TEST_F(PlannerTest, LocationPlanningForPassThroughExplicitAndImplicitSubgraphInputs) {
// Types
TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
@ -573,7 +573,6 @@ TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
std::vector<NodeArg*> outputs;
/* Inputs: iter_num, cond_in, loop carried state variables.
iter_num_in cond_in [loop_state_var]
(unused) | |
[Identity] [If]
@ -729,7 +728,105 @@ TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
}
}
#endif
TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) {
  // This is a simple model that has one outer scope initializer and an `If` node
  // and that initializer is ONLY used in nested subgraphs (both the `If` subgraphs).
  // We want to test that the location planned for this initializer accounts for
  // its usage in the nested subgraphs and statically determines the right location
  // for it (without defaulting to CPU).

  // Types
  TypeProto float_tensor;
  float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
  float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("dim_param");

  TypeProto bool_scalar;
  bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
  bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

  // NOTE: the lambdas below return a value, so only the non-fatal EXPECT_* checks
  // can be used inside them (fatal ASSERT_* checks require a void function).
  auto create_model = [&float_tensor, &bool_scalar]() -> Model {
    // Builds one `If` branch: Add(abs_data_out, init_data), where both inputs
    // are outer scope values.
    auto create_if_subgraph = [&float_tensor](bool is_then) -> GraphProto {
      Model model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger());
      auto& graph = model.MainGraph();

      auto& outer_scope_0 = graph.GetOrCreateNodeArg("abs_data_out", &float_tensor);
      graph.AddOuterScopeNodeArg("abs_data_out");

      auto& outer_scope_1 = graph.GetOrCreateNodeArg("init_data", &float_tensor);
      graph.AddOuterScopeNodeArg("init_data");

      auto& if_out = graph.GetOrCreateNodeArg(is_then ? "if_then_out" : "if_else_out", &float_tensor);
      graph.AddNode("if_out", "Add", "add", {&outer_scope_0, &outer_scope_1}, {&if_out});

      auto status = graph.Resolve();
      EXPECT_EQ(status, Status::OK());

      return graph.ToGraphProto();
    };

    onnxruntime::Model model("main_graph", false, ModelMetaData(),
                             PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                             {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
    auto& main_graph = model.MainGraph();

    // Abs-0
    auto& abs_data_in = main_graph.GetOrCreateNodeArg("abs_data_in", &float_tensor);
    auto& abs_data_out = main_graph.GetOrCreateNodeArg("abs_data_out", &float_tensor);
    main_graph.AddNode("abs_0", "Abs", "node abs", {&abs_data_in}, {&abs_data_out});

    // If
    auto& if_in = main_graph.GetOrCreateNodeArg("if_in", &bool_scalar);
    auto& if_out = main_graph.GetOrCreateNodeArg("if_out", &float_tensor);
    auto& node = main_graph.AddNode("if_out", "If", "If", {&if_in}, {&if_out});
    node.AddAttribute("then_branch", create_if_subgraph(true));
    node.AddAttribute("else_branch", create_if_subgraph(false));

    // Add initializer to the graph
    ONNX_NAMESPACE::TensorProto tensor;
    tensor.add_dims(1);
    tensor.add_float_data(1.0f);
    tensor.set_data_type(TensorProto_DataType_FLOAT);
    tensor.set_name("init_data");
    main_graph.AddInitializedTensor(tensor);

    // Main graph's inputs/outputs
    main_graph.SetInputs({&abs_data_in, &if_in});
    main_graph.SetOutputs({&if_out});

    auto status = main_graph.Resolve();
    EXPECT_EQ(status, Status::OK());

    return model;
  };

  // Create and load session
  SessionOptions so;
  InferenceSession sess{so, GetEnvironment()};

  auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
  ASSERT_TRUE(status.IsOK());

  std::string s1;
  const bool rc = create_model().ToProto().SerializeToString(&s1);
  // Loading an incompletely serialized model is pointless, so stop here on failure.
  ASSERT_TRUE(rc);
  std::stringstream sstr(s1);

  status = sess.Load(sstr);
  ASSERT_TRUE(status.IsOK());

  status = sess.Initialize();
  ASSERT_TRUE(status.IsOK());

  // Check planned locations for the initializer
  const auto& main_graph_session_state = sess.GetSessionState();
  const auto& main_graph_ort_value_index_map = main_graph_session_state.GetOrtValueNameIdxMap();
  const auto* main_graph_plan = main_graph_session_state.GetExecutionPlan();
  ASSERT_NE(main_graph_plan, nullptr);

  // The lookup result must be checked: on failure `init_data_index` would be left
  // unset and then used to index into the allocation plan.
  OrtValueIndex init_data_index;
  status = main_graph_ort_value_index_map.GetIdx("init_data", init_data_index);
  ASSERT_TRUE(status.IsOK());

  EXPECT_EQ(main_graph_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::GPU);
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -29,7 +29,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \
--include_ops_by_config /home/onnxruntimedev/.test_data/include_no_operators.config
# set current size limit to BINARY_SIZE_LIMIT_IN_BYTES.
BINARY_SIZE_LIMIT_IN_BYTES=1255000
BINARY_SIZE_LIMIT_IN_BYTES=1256000
echo "The current preset binary size limit is $BINARY_SIZE_LIMIT_IN_BYTES"
python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py \
--threshold=$BINARY_SIZE_LIMIT_IN_BYTES \