Update fetch setup for execution to use the allocation plan. (#1668)

This commit is contained in:
Ke Zhang 2019-08-21 19:58:05 -07:00 committed by GitHub
parent 6f70a78e1f
commit b53f40a886
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -71,7 +71,7 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
const FeedsFetchesManager::MLValueCopyInfo& copy_info,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue) {
if (copy_info.allocation_provider == nullptr){
if (copy_info.allocation_provider == nullptr) {
target_mlvalue = source_mlvalue;
return Status::OK();
}
@ -202,18 +202,16 @@ static common::Status CachedCopyInputsAcrossDevices(
// Set up fetches for execution. Use any provided fetches directly if the provider matches.
// If the provider doesn't match, we don't know what device the execution output may be on, so we can't assume the
// output can be returned to the user directly.
// TODO: We should be able to use the allocation plan to know which device an output will be on.
static common::Status SetupFetchesForExecute(const SessionState& session_state,
const std::vector<std::string>& output_names,
std::vector<OrtValue>& fetches, std::vector<OrtValue>& new_fetches,
std::vector<bool>* copy_to_new_fetches_cached_values) {
ORT_ENFORCE(new_fetches.empty());
const auto& execution_providers = session_state.GetExecutionProviders();
auto num_outputs = output_names.size();
new_fetches.resize(num_outputs);
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();
const auto* exec_plan = session_state.GetExecutionPlan();
// track which fetches can be copied to new_fetches and used directly in the execution.
std::vector<bool> local_can_copy_flags(num_outputs, false);
@ -254,16 +252,12 @@ static common::Status SetupFetchesForExecute(const SessionState& session_state,
continue;
}
const auto& node_provider_type = node.GetExecutionProviderType();
const auto& provided_tensor = provided_mlvalue.Get<Tensor>();
const auto& provided_tensor_loc = provided_tensor.Location();
const auto* tensor_provider = execution_providers.Get(provided_tensor_loc);
if (!tensor_provider) {
tensor_provider = execution_providers.Get(onnxruntime::kCpuExecutionProvider);
}
int arg_index;
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg->Name(), arg_index));
const auto& planned_device = exec_plan->GetLocation(arg_index).device;
const auto& provided_tensor_device = provided_mlvalue.Get<Tensor>().Location().device;
auto tensor_provider_type = tensor_provider->Type();
if (node_provider_type == tensor_provider_type) {
if (planned_device == provided_tensor_device) {
new_fetches[idx] = fetches[idx];
local_can_copy_flags[idx] = true;
continue;