mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Fix location planning for initializers used only in nested subgraphs (#8642)
This commit is contained in:
parent
4dc0ddf606
commit
acd9db7fad
8 changed files with 324 additions and 42 deletions
|
|
@ -381,6 +381,12 @@ class Node {
|
|||
return attr_to_subgraph_map_;
|
||||
}
|
||||
|
||||
/** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node.
|
||||
@returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
|
||||
nullptr if the Node has no subgraphs.
|
||||
*/
|
||||
std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;
|
||||
|
||||
/** Gets the execution ProviderType that this node will be executed by. */
|
||||
ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,27 @@ using namespace onnxruntime::common;
|
|||
using namespace ONNX_NAMESPACE;
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace NestedSubgraphInfoDetails {
|
||||
|
||||
// Used to compose a unique key to identify a nested subgraph
|
||||
// relative to a current graph level (which in turn is identified using a "base")
|
||||
std::string ComposeNestedSubgraphInfoKeyHelper(const std::string& base,
|
||||
size_t graph_depth,
|
||||
NodeIndex node_index,
|
||||
const std::string& attr_name) {
|
||||
std::ostringstream ss;
|
||||
|
||||
// key = base + graph depth + current graph node index + attr name corresponding to the subgraph
|
||||
ss << base;
|
||||
ss << graph_depth;
|
||||
ss << node_index;
|
||||
ss << attr_name;
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace NestedSubgraphInfoDetails
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind) {
|
||||
switch (alloc_kind) {
|
||||
case AllocKind::kAllocate:
|
||||
|
|
@ -107,7 +128,7 @@ std::ostream& operator<<(std::ostream& out, std::pair<const SequentialExecutionP
|
|||
}
|
||||
|
||||
static const KernelCreateInfo& GetKernelCreateInfo(
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const KernelCreateInfoMap& kernel_create_info_map,
|
||||
NodeIndex node_index) {
|
||||
auto entry = kernel_create_info_map.find(node_index);
|
||||
ORT_ENFORCE(entry != kernel_create_info_map.cend(),
|
||||
|
|
@ -120,7 +141,8 @@ class PlannerImpl {
|
|||
public:
|
||||
PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args, const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const KernelCreateInfoMap& kernel_create_info_map,
|
||||
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
|
||||
|
|
@ -131,6 +153,7 @@ class PlannerImpl {
|
|||
outer_scope_node_args_(outer_scope_node_args),
|
||||
execution_providers_(providers),
|
||||
kernel_create_info_map_(kernel_create_info_map),
|
||||
subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps),
|
||||
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
|
||||
ort_value_name_idx_map_(ort_value_name_idx_map) {}
|
||||
|
||||
|
|
@ -145,7 +168,8 @@ class PlannerImpl {
|
|||
const std::vector<const NodeArg*>& outer_scope_node_args_;
|
||||
const ExecutionProviders& execution_providers_;
|
||||
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map_;
|
||||
const KernelCreateInfoMap& kernel_create_info_map_;
|
||||
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps_;
|
||||
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map_;
|
||||
|
||||
|
|
@ -609,11 +633,12 @@ class PlannerImpl {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
OrtMemoryInfo GetLocationForNodeInput(size_t input_index, const Node& node) {
|
||||
OrtMemoryInfo GetLocationForNodeInput(size_t input_index, const Node& node,
|
||||
const KernelCreateInfoMap& kernel_create_info_map) {
|
||||
auto* p_provider = execution_providers_.Get(node);
|
||||
ORT_ENFORCE(p_provider);
|
||||
|
||||
const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map_, node.Index());
|
||||
const KernelCreateInfo& kernel_create_info = GetKernelCreateInfo(kernel_create_info_map, node.Index());
|
||||
|
||||
if (utils::IsInputOnCpu(node, &kernel_create_info, input_index))
|
||||
// weights are not output from any node, so it's OK to put its location on CPU provider
|
||||
|
|
@ -621,29 +646,95 @@ class PlannerImpl {
|
|||
return p_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
|
||||
}
|
||||
|
||||
Status GeneratePlanForWeights() {
|
||||
auto& weights = graph_viewer_.GetAllInitializedTensors();
|
||||
std::vector<std::vector<OrtMemoryInfo>> locations(plan_.allocation_plan.size());
|
||||
for (const auto& node : graph_viewer_.Nodes()) {
|
||||
auto status = onnxruntime::Node::ForEachWithIndex(
|
||||
node.InputDefs(), [this, &locations, &node, &weights](const onnxruntime::NodeArg& def, size_t index) {
|
||||
auto sub_status = Status::OK();
|
||||
ORT_TRY {
|
||||
auto& def_name = def.Name();
|
||||
if (!weights.count(def_name)) return Status::OK();
|
||||
auto wt_index = Index(def_name);
|
||||
locations[wt_index].emplace_back(GetLocationForNodeInput(index, node));
|
||||
}
|
||||
ORT_CATCH(const std::exception& ex) {
|
||||
ORT_HANDLE_EXCEPTION([&]() {
|
||||
sub_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what());
|
||||
});
|
||||
}
|
||||
return sub_status;
|
||||
});
|
||||
void GeneratePlanForWeightsHelper(const GraphViewer& graph_viewer,
|
||||
const InitializedTensorSet& weights,
|
||||
const KernelCreateInfoMap& kernel_create_info_map,
|
||||
const std::string& subgraph_kernel_create_info_map_key_base,
|
||||
size_t graph_depth,
|
||||
/*out*/ std::vector<std::vector<OrtMemoryInfo>>& locations) {
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
const auto& input_node_args = node.InputDefs();
|
||||
size_t num_node_inputs = input_node_args.size();
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
for (size_t node_input_index = 0; node_input_index < num_node_inputs; ++node_input_index) {
|
||||
auto input_node_arg = input_node_args[node_input_index];
|
||||
|
||||
// Skip processing missing optional inputs
|
||||
if (!input_node_arg->Exists()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& def_name = input_node_arg->Name();
|
||||
|
||||
// This node input doesn't correspond to any of the weights
|
||||
if (!weights.count(def_name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// While processing subgraphs, if we don't see an entry in the implicit
|
||||
// inputs of the node containing the subgraph, it is a shadow value.
|
||||
auto is_shadow_value_in_subgraph = [](const Node& subgraph_parent_node,
|
||||
const std::string& def_name) -> bool {
|
||||
bool is_shadow_value_in_subgraph = true;
|
||||
for (const auto& implicit_input : subgraph_parent_node.ImplicitInputDefs()) {
|
||||
if (implicit_input->Name() == def_name) {
|
||||
is_shadow_value_in_subgraph = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return is_shadow_value_in_subgraph;
|
||||
};
|
||||
|
||||
// Skip processing shadow values in subgraphs
|
||||
if (graph_depth > 0) {
|
||||
// We are processing a subgraph if we enter this
|
||||
const auto* parent_node = graph_viewer.ParentNode();
|
||||
|
||||
// Skip processing if it is a shadow value
|
||||
if (is_shadow_value_in_subgraph(*parent_node, def_name)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto wt_index = Index(def_name);
|
||||
locations[wt_index].emplace_back(
|
||||
GetLocationForNodeInput(node_input_index, node, kernel_create_info_map));
|
||||
}
|
||||
|
||||
// If the node has subgraphs (i.e.) control flow nodes,
|
||||
// walk the nodes in those subgraphs as well to best determine
|
||||
// the location for the OrtValue corresponding to the weights
|
||||
// (i.e.) do a recursion
|
||||
if (node.ContainsSubgraph()) {
|
||||
// A node may contain multiple subgraphs - so iterate through all of them
|
||||
for (auto& name_to_subgraph : node.GetAttributeNameToSubgraphMap()) {
|
||||
GraphViewer subgraph_viewer(*name_to_subgraph.second);
|
||||
|
||||
const auto& local_subgraph_kernel_create_info_map_key =
|
||||
NestedSubgraphInfoDetails::ComposeNestedSubgraphInfoKeyHelper(subgraph_kernel_create_info_map_key_base,
|
||||
graph_depth, node.Index(), name_to_subgraph.first);
|
||||
|
||||
auto specific_subgraph_kernel_create_info_map = subgraphs_kernel_create_info_maps_.find(local_subgraph_kernel_create_info_map_key);
|
||||
ORT_ENFORCE(specific_subgraph_kernel_create_info_map != subgraphs_kernel_create_info_maps_.end());
|
||||
|
||||
GeneratePlanForWeightsHelper(subgraph_viewer,
|
||||
weights,
|
||||
specific_subgraph_kernel_create_info_map->second,
|
||||
local_subgraph_kernel_create_info_map_key,
|
||||
graph_depth + 1,
|
||||
locations);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status GeneratePlanForWeights() {
|
||||
std::vector<std::vector<OrtMemoryInfo>> locations(plan_.allocation_plan.size());
|
||||
|
||||
GeneratePlanForWeightsHelper(graph_viewer_, graph_viewer_.GetAllInitializedTensors(),
|
||||
kernel_create_info_map_, "", 0, locations);
|
||||
|
||||
for (size_t i = 0; i != locations.size(); ++i) {
|
||||
const std::vector<OrtMemoryInfo>& loc = locations[i];
|
||||
if (loc.empty()) continue;
|
||||
|
|
@ -671,7 +762,7 @@ class PlannerImpl {
|
|||
// Should only be used after ProcessDef()
|
||||
Status ComputeReusePlan() {
|
||||
std::vector<SequentialExecutionPlan::NodeExecutionPlan>& execution_plan(plan_.execution_plan);
|
||||
//copy the usecounts to an vector, before computing reuse
|
||||
//copy the use counts to a vector, before computing reuse
|
||||
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
|
||||
std::vector<int> ort_value_usecount;
|
||||
for (auto ort_value_info : ort_value_info_) {
|
||||
|
|
@ -1049,7 +1140,7 @@ class PlannerImpl {
|
|||
}
|
||||
}
|
||||
#endif
|
||||
}; // namespace onnxruntime
|
||||
};
|
||||
|
||||
Status PlannerImpl::CreatePlan() {
|
||||
auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder(context_.GetExecutionOrder());
|
||||
|
|
@ -1098,7 +1189,8 @@ Status SequentialPlanner::CreatePlan(
|
|||
const onnxruntime::GraphViewer& graph_viewer,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const KernelCreateInfoMap& kernel_create_info_map,
|
||||
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context,
|
||||
|
|
@ -1107,7 +1199,8 @@ Status SequentialPlanner::CreatePlan(
|
|||
plan = std::make_unique<SequentialExecutionPlan>();
|
||||
|
||||
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
|
||||
kernel_create_info_map, outer_scope_node_arg_to_location_map,
|
||||
kernel_create_info_map, subgraphs_kernel_create_info_maps,
|
||||
outer_scope_node_arg_to_location_map,
|
||||
ort_value_name_idx_map, context, *plan);
|
||||
|
||||
return planner.CreatePlan();
|
||||
|
|
|
|||
|
|
@ -15,11 +15,23 @@ class TensorShapeProto;
|
|||
}
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace NestedSubgraphInfoDetails {
|
||||
|
||||
// Used to compose a unique key to identify a nested subgraph
|
||||
// relative to a current graph level (which in turn is identified using a "base")
|
||||
std::string ComposeNestedSubgraphInfoKeyHelper(const std::string& base, size_t graph_depth,
|
||||
NodeIndex node_index, const std::string& attr_name);
|
||||
|
||||
} // namespace NestedSubgraphInfoDetails
|
||||
|
||||
class ExecutionProviders;
|
||||
struct KernelCreateInfo;
|
||||
class KernelRegistryManager;
|
||||
class OrtValueNameIdxMap;
|
||||
|
||||
using KernelCreateInfoMap = std::unordered_map<onnxruntime::NodeIndex, gsl::not_null<const KernelCreateInfo*>>;
|
||||
using SubgraphsKernelCreateInfoMaps = std::unordered_map<std::string, KernelCreateInfoMap>;
|
||||
|
||||
// ISequentialPlannerContext abstracts how the planner accesses information (such as inferred shape)
|
||||
// to do the planning.
|
||||
class ISequentialPlannerContext {
|
||||
|
|
@ -65,7 +77,8 @@ class SequentialPlanner {
|
|||
const Node* parent_node, const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const NodeArg*>& outer_scope_node_args,
|
||||
const ExecutionProviders& providers,
|
||||
const std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>>& kernel_create_info_map,
|
||||
const KernelCreateInfoMap& kernel_create_info_map,
|
||||
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
|
||||
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map,
|
||||
const OrtValueNameIdxMap& ort_value_name_idx_map,
|
||||
const ISequentialPlannerContext& context,
|
||||
|
|
|
|||
|
|
@ -1174,6 +1174,57 @@ static Status OuterScopeNodeArgLocationAccumulator(const SequentialExecutionPlan
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// We accumulate all nested subgraph(s) kernel create info maps relative to the current depth
|
||||
// (i.e.) if we were on the first nested subgraph, we accumulate information from ALL the
|
||||
// nested subgraphs within it.
|
||||
// This information is necessary to plan the right location for initializers
|
||||
// in a given level because they could be used in one of the nested subgraphs relative to the
|
||||
// current level (not just within the same level or even one level deep).
|
||||
// Since we need to package up information from multiple levels of nested subgraphs, the key we use
|
||||
// is "{key_for_node_containing_subgraph} + current_depth + node_index_containing_the_subgraph + attribute_name".
|
||||
// {key_for_node_containing_subgraph} is empty for the main graph.
|
||||
|
||||
// For example, if we want to store information corresponding to a nested subgraph wrt to the main graph and
|
||||
// the node index of the node in the main graph was 2 and the attribute containing the specific
|
||||
// subgraph was "then_branch", the key would be depth + node_index + attribute = 0 + 2 + then_branch
|
||||
// = "02then_branch".
|
||||
|
||||
// If that subgraph contained another subgraph at node index 1, then the key would be,
|
||||
// {02then_branch} + 1 + 1 + "then_branch" = "02then_branch11then_branch".
|
||||
|
||||
static void AccumulateAllNestedSubgraphsInfo(
|
||||
const SessionState& session_state,
|
||||
const std::string& subgraph_kernel_create_info_map_key_base,
|
||||
size_t graph_depth,
|
||||
/*out*/ SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps) {
|
||||
for (const auto& entry : session_state.GetSubgraphSessionStateMap()) {
|
||||
auto node_index = entry.first;
|
||||
|
||||
for (const auto& name_to_subgraph_session_state : entry.second) {
|
||||
const auto& subgraph_attr_name = name_to_subgraph_session_state.first;
|
||||
|
||||
SessionState& subgraph_session_state = *name_to_subgraph_session_state.second;
|
||||
|
||||
const auto& local_subgraph_kernel_create_info_map_key =
|
||||
NestedSubgraphInfoDetails::ComposeNestedSubgraphInfoKeyHelper(subgraph_kernel_create_info_map_key_base,
|
||||
graph_depth, node_index, subgraph_attr_name);
|
||||
|
||||
// The end user is never likely to see an error with the following line.
|
||||
// Points to an internal processing error if we hit this.
|
||||
ORT_ENFORCE(subgraphs_kernel_create_info_maps.find(local_subgraph_kernel_create_info_map_key) ==
|
||||
subgraphs_kernel_create_info_maps.end());
|
||||
|
||||
subgraphs_kernel_create_info_maps.insert({local_subgraph_kernel_create_info_map_key,
|
||||
subgraph_session_state.GetKernelCreateInfoMap()});
|
||||
|
||||
// Recurse into the subgraph session state
|
||||
AccumulateAllNestedSubgraphsInfo(subgraph_session_state,
|
||||
local_subgraph_kernel_create_info_map_key,
|
||||
graph_depth + 1, subgraphs_kernel_create_info_maps);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
|
||||
KernelRegistryManager& kernel_registry_manager,
|
||||
_In_opt_ const Node* parent_node,
|
||||
|
|
@ -1201,9 +1252,13 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
});
|
||||
}
|
||||
|
||||
SubgraphsKernelCreateInfoMaps subgraphs_kernel_create_info_maps;
|
||||
AccumulateAllNestedSubgraphsInfo(*this, "", 0, subgraphs_kernel_create_info_maps);
|
||||
|
||||
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order, session_options.enable_mem_reuse);
|
||||
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
|
||||
execution_providers_, kernel_create_info_map_,
|
||||
subgraphs_kernel_create_info_maps,
|
||||
outer_scope_node_arg_to_location_map,
|
||||
ort_value_name_idx_map_, context, p_seq_exec_plan_));
|
||||
//Record the allocation plan
|
||||
|
|
|
|||
|
|
@ -77,6 +77,12 @@ class MemoryInfo;
|
|||
* Then you can use:
|
||||
* s.GetKernel(...);
|
||||
*/
|
||||
|
||||
// subgraph SessionState. entry for node containing subgraph, with value containing attribute:SessionState pair
|
||||
// as a node may contain multiple subgraphs (e.g. 'If' has one for both the 'then' and 'else' branches).
|
||||
using SubgraphSessionStateMap =
|
||||
std::unordered_map<onnxruntime::NodeIndex, std::unordered_map<std::string, std::unique_ptr<SessionState>>>;
|
||||
|
||||
class SessionState {
|
||||
public:
|
||||
SessionState(Graph& graph,
|
||||
|
|
@ -319,6 +325,14 @@ class SessionState {
|
|||
return used_shared_pre_packed_weights_counter_;
|
||||
}
|
||||
|
||||
const KernelCreateInfoMap& GetKernelCreateInfoMap() const {
|
||||
return kernel_create_info_map_;
|
||||
}
|
||||
|
||||
const SubgraphSessionStateMap& GetSubgraphSessionStateMap() const {
|
||||
return subgraph_session_states_;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
|
||||
void IncrementGraphExecutionCounter() {
|
||||
++graph_executions_counter_;
|
||||
|
|
@ -385,7 +399,7 @@ class SessionState {
|
|||
}
|
||||
|
||||
// KernelCreateInfo for each node so we do kernel lookup once
|
||||
std::unordered_map<NodeIndex, gsl::not_null<const KernelCreateInfo*>> kernel_create_info_map_;
|
||||
KernelCreateInfoMap kernel_create_info_map_;
|
||||
|
||||
// If we compile kernels in a minimal build we need a way to find the kernel using the hash.
|
||||
// We populate this map when doing the kernel compilation in GraphPartitioner, and use it in LoadFromOrtFormat.
|
||||
|
|
@ -469,10 +483,6 @@ class SessionState {
|
|||
NameNodeInfoMapType input_names_to_nodeinfo_mapping_;
|
||||
NameNodeInfoMapType output_names_to_nodeinfo_mapping_;
|
||||
|
||||
// subgraph SessionState. entry for node containing subgraph, with value containing attribute:SessionState pair
|
||||
// as a node may contain multiple subgraphs (e.g. 'If' has one for both the 'then' and 'else' branches).
|
||||
using SubgraphSessionStateMap =
|
||||
std::unordered_map<onnxruntime::NodeIndex, std::unordered_map<std::string, std::unique_ptr<SessionState>>>;
|
||||
SubgraphSessionStateMap subgraph_session_states_;
|
||||
|
||||
// either threadpool could be nullptr
|
||||
|
|
|
|||
|
|
@ -951,6 +951,14 @@ std::vector<gsl::not_null<const Graph*>> Node::GetSubgraphs() const {
|
|||
return subgraphs;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, gsl::not_null<const Graph*>> Node::GetAttributeNameToSubgraphMap() const {
|
||||
std::unordered_map<std::string, gsl::not_null<const Graph*>> attr_to_subgraphs;
|
||||
for (auto& entry : attr_to_subgraph_map_) {
|
||||
attr_to_subgraphs.insert({entry.first, entry.second});
|
||||
}
|
||||
return attr_to_subgraphs;
|
||||
}
|
||||
|
||||
void Node::ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func,
|
||||
bool include_missing_optional_defs) const {
|
||||
for (const auto* arg : InputDefs()) {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ using namespace ONNX_NAMESPACE;
|
|||
// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
|
||||
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
|
||||
// GCC 4.x.
|
||||
// (This static var is referenced in `PassThroughExplicitAndImplicitSubgraphInputs` test)
|
||||
// (This static var is referenced in some tests below)
|
||||
const OrtDevice::DeviceType OrtDevice::GPU;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -278,7 +278,7 @@ class PlannerTest : public ::testing::Test {
|
|||
SequentialPlannerTestContext test_context(&shape_map_);
|
||||
|
||||
status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph_), outer_scope_node_args, execution_providers_,
|
||||
kernel_create_info_map, {}, state_->GetOrtValueNameIdxMap(), test_context,
|
||||
kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context,
|
||||
plan_);
|
||||
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
|
@ -514,7 +514,7 @@ TEST_F(PlannerTest, PlanOutputTest) {
|
|||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
|
||||
TEST_F(PlannerTest, LocationPlanningForPassThroughExplicitAndImplicitSubgraphInputs) {
|
||||
// Types
|
||||
TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -573,7 +573,6 @@ TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
|
|||
std::vector<NodeArg*> outputs;
|
||||
|
||||
/* Inputs: iter_num, cond_in, loop carried state variables.
|
||||
|
||||
iter_num_in cond_in [loop_state_var]
|
||||
(unused) | |
|
||||
[Identity] [If]
|
||||
|
|
@ -729,7 +728,105 @@ TEST_F(PlannerTest, PassThroughExplicitAndImplicitSubgraphInputs) {
|
|||
EXPECT_EQ(first_subgraph_plan->allocation_plan[abs_data_1_out_index].location.device.Type(), OrtDevice::GPU);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) {
|
||||
// This a simple model that has one outer scope initializer and an `If` node
|
||||
// and that initializer is ONLY used in nested subgraphs (both the `If` subgraphs).
|
||||
// We want to test that the location planned for this initializer accounts for
|
||||
// its usage in the nested subgraphs and statically determines the right location
|
||||
// for it (without defaulting to CPU).
|
||||
|
||||
// Types
|
||||
TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("dim_param");
|
||||
|
||||
TypeProto bool_scalar;
|
||||
bool_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
|
||||
bool_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
auto create_model = [&float_tensor, &bool_scalar]() -> Model {
|
||||
auto create_if_subgraph = [&float_tensor](bool is_then) -> GraphProto {
|
||||
Model model("if_branch_subgraph", true, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
auto& outer_scope_0 = graph.GetOrCreateNodeArg("abs_data_out", &float_tensor);
|
||||
graph.AddOuterScopeNodeArg("abs_data_out");
|
||||
|
||||
auto& outer_scope_1 = graph.GetOrCreateNodeArg("init_data", &float_tensor);
|
||||
graph.AddOuterScopeNodeArg("init_data");
|
||||
|
||||
auto& if_out = graph.GetOrCreateNodeArg(is_then ? "if_then_out" : "if_else_out", &float_tensor);
|
||||
graph.AddNode("if_out", "Add", "add", {&outer_scope_0, &outer_scope_1}, {&if_out});
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return graph.ToGraphProto();
|
||||
};
|
||||
|
||||
onnxruntime::Model model("main_graph", false, ModelMetaData(),
|
||||
PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
{{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
|
||||
auto& main_graph = model.MainGraph();
|
||||
|
||||
// Abs-0
|
||||
auto& abs_data_in = main_graph.GetOrCreateNodeArg("abs_data_in", &float_tensor);
|
||||
auto& abs_data_out = main_graph.GetOrCreateNodeArg("abs_data_out", &float_tensor);
|
||||
main_graph.AddNode("abs_0", "Abs", "node abs", {&abs_data_in}, {&abs_data_out});
|
||||
|
||||
// If
|
||||
auto& if_in = main_graph.GetOrCreateNodeArg("if_in", &bool_scalar);
|
||||
auto& if_out = main_graph.GetOrCreateNodeArg("if_out", &float_tensor);
|
||||
auto& node = main_graph.AddNode("if_out", "If", "If", {&if_in}, {&if_out});
|
||||
node.AddAttribute("then_branch", create_if_subgraph(true));
|
||||
node.AddAttribute("else_branch", create_if_subgraph(false));
|
||||
|
||||
// Add initializer to the graph
|
||||
ONNX_NAMESPACE::TensorProto tensor;
|
||||
tensor.add_dims(1);
|
||||
tensor.add_float_data(1.0f);
|
||||
tensor.set_data_type(TensorProto_DataType_FLOAT);
|
||||
tensor.set_name("init_data");
|
||||
main_graph.AddInitializedTensor(tensor);
|
||||
|
||||
// Main graph's inputs/outputs
|
||||
main_graph.SetInputs({&abs_data_in, &if_in});
|
||||
main_graph.SetOutputs({&if_out});
|
||||
|
||||
auto status = main_graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
|
||||
return model;
|
||||
};
|
||||
|
||||
// Create and load session
|
||||
SessionOptions so;
|
||||
InferenceSession sess{so, GetEnvironment()};
|
||||
|
||||
auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
std::string s1;
|
||||
const bool rc = create_model().ToProto().SerializeToString(&s1);
|
||||
EXPECT_EQ(rc, true);
|
||||
std::stringstream sstr(s1);
|
||||
|
||||
status = sess.Load(sstr);
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
status = sess.Initialize();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
|
||||
// Check planned locations for the initializer
|
||||
const auto& main_graph_session_state = sess.GetSessionState();
|
||||
const auto& main_graph_ort_value_index_map = main_graph_session_state.GetOrtValueNameIdxMap();
|
||||
const auto* main_graph_plan = main_graph_session_state.GetExecutionPlan();
|
||||
|
||||
OrtValueIndex init_data_index;
|
||||
main_graph_ort_value_index_map.GetIdx("init_data", init_data_index);
|
||||
|
||||
EXPECT_EQ(main_graph_plan->allocation_plan[init_data_index].location.device.Type(), OrtDevice::GPU);
|
||||
}
|
||||
#endif
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \
|
|||
--include_ops_by_config /home/onnxruntimedev/.test_data/include_no_operators.config
|
||||
|
||||
# set current size limit to BINARY_SIZE_LIMIT_IN_BYTES.
|
||||
BINARY_SIZE_LIMIT_IN_BYTES=1255000
|
||||
BINARY_SIZE_LIMIT_IN_BYTES=1256000
|
||||
echo "The current preset binary size limit is $BINARY_SIZE_LIMIT_IN_BYTES"
|
||||
python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py \
|
||||
--threshold=$BINARY_SIZE_LIMIT_IN_BYTES \
|
||||
|
|
|
|||
Loading…
Reference in a new issue