mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Update setup of fetches for execution to use the allocation plan. (#1668)
This commit is contained in:
parent
6f70a78e1f
commit
b53f40a886
1 changed file with 8 additions and 14 deletions
|
|
@ -71,7 +71,7 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
|
|||
const FeedsFetchesManager::MLValueCopyInfo& copy_info,
|
||||
const OrtValue& source_mlvalue,
|
||||
OrtValue& target_mlvalue) {
|
||||
if (copy_info.allocation_provider == nullptr){
|
||||
if (copy_info.allocation_provider == nullptr) {
|
||||
target_mlvalue = source_mlvalue;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -202,18 +202,16 @@ static common::Status CachedCopyInputsAcrossDevices(
|
|||
// Setup fetches for execution. Use any provided fetches directly if the provider matches.
|
||||
// If the provider doesn't match, we don't know what device the execution output may be on, so can't assume the output
|
||||
// can be returned to the user directly.
|
||||
// TODO: We should be able to use the allocation plan to know which device an output will be on.
|
||||
static common::Status SetupFetchesForExecute(const SessionState& session_state,
|
||||
const std::vector<std::string>& output_names,
|
||||
std::vector<OrtValue>& fetches, std::vector<OrtValue>& new_fetches,
|
||||
std::vector<bool>* copy_to_new_fetches_cached_values) {
|
||||
ORT_ENFORCE(new_fetches.empty());
|
||||
|
||||
const auto& execution_providers = session_state.GetExecutionProviders();
|
||||
auto num_outputs = output_names.size();
|
||||
|
||||
new_fetches.resize(num_outputs);
|
||||
|
||||
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();
|
||||
const auto* exec_plan = session_state.GetExecutionPlan();
|
||||
// track which fetches can be copied to new_fetches and used directly in the execution.
|
||||
std::vector<bool> local_can_copy_flags(num_outputs, false);
|
||||
|
||||
|
|
@ -254,16 +252,12 @@ static common::Status SetupFetchesForExecute(const SessionState& session_state,
|
|||
continue;
|
||||
}
|
||||
|
||||
const auto& node_provider_type = node.GetExecutionProviderType();
|
||||
const auto& provided_tensor = provided_mlvalue.Get<Tensor>();
|
||||
const auto& provided_tensor_loc = provided_tensor.Location();
|
||||
const auto* tensor_provider = execution_providers.Get(provided_tensor_loc);
|
||||
if (!tensor_provider) {
|
||||
tensor_provider = execution_providers.Get(onnxruntime::kCpuExecutionProvider);
|
||||
}
|
||||
int arg_index;
|
||||
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg->Name(), arg_index));
|
||||
const auto& planned_device = exec_plan->GetLocation(arg_index).device;
|
||||
const auto& provided_tensor_device = provided_mlvalue.Get<Tensor>().Location().device;
|
||||
|
||||
auto tensor_provider_type = tensor_provider->Type();
|
||||
if (node_provider_type == tensor_provider_type) {
|
||||
if (planned_device == provided_tensor_device) {
|
||||
new_fetches[idx] = fetches[idx];
|
||||
local_can_copy_flags[idx] = true;
|
||||
continue;
|
||||
|
|
|
|||
Loading…
Reference in a new issue