Add ability to batch device copy for graph inputs and outputs. (#3580)

* Add ability to batch device copy for graph inputs and outputs.
This commit is contained in:
Scott McKay 2020-04-19 17:51:07 +10:00 committed by GitHub
parent ea62b3435a
commit 7d5348f87e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 15 deletions

View file

@ -9,11 +9,11 @@ common::Status IDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
return CopyTensor(src, dst, 0);
}
common::Status IDataTransfer::CopyTensors(const Tensor* src, Tensor* dst, int size) const {
ORT_ENFORCE(nullptr != src && nullptr != dst);
for (int i = 0; i < size; ++i) {
ORT_RETURN_IF_ERROR(CopyTensor(src[i], dst[i], 0));
// Default batched copy. Walks the pairs in the order given and copies each one individually via
// CopyTensor, forwarding that pair's exec_queue_id. Returns the first non-OK status encountered;
// returns OK if every copy succeeded (or the list is empty).
common::Status IDataTransfer::CopyTensors(const std::vector<IDataTransfer::SrcDstPair>& src_dst_pairs) const {
  for (auto entry = src_dst_pairs.cbegin(), last = src_dst_pairs.cend(); entry != last; ++entry) {
    // reference_wrapper members convert implicitly to `const Tensor&` / `Tensor&`
    ORT_RETURN_IF_ERROR(CopyTensor(entry->src, entry->dst, entry->exec_queue_id));
  }

  return Status::OK();
}

View file

@ -17,7 +17,15 @@ class IDataTransfer {
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const;
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const = 0;
virtual common::Status CopyTensors(const Tensor* src, Tensor* dst, int size) const;
// One entry of a batched copy request: the tensor to read from, the tensor to write to, and the
// execution queue id to use for that copy.
// The members are non-owning references, so both tensors must outlive any use of the pair.
// Kept as a plain aggregate so it can be brace-initialized as {src, dst, exec_queue_id}.
struct SrcDstPair {
// tensor to copy from
std::reference_wrapper<const Tensor> src;
// tensor to copy into
std::reference_wrapper<Tensor> dst;
// forwarded as the exec_queue_id argument of CopyTensor when this pair is copied
int exec_queue_id;
};
// batched copy. default implementation copies each entry sequentially, and returns on first failure.
virtual common::Status CopyTensors(const std::vector<SrcDstPair>& src_dst_pairs) const;
};
class CPUDataTransfer : public IDataTransfer {

View file

@ -25,7 +25,6 @@ const IDataTransfer* DataTransferManager::GetDataTransfer(const OrtDevice& src_d
return nullptr;
}
// Convenience overload: copies src into dst by forwarding to the three-argument CopyTensor
// with an exec_queue_id of 0.
Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst) const {
  const int default_exec_queue_id = 0;
  return CopyTensor(src, dst, default_exec_queue_id);
}
@ -51,4 +50,55 @@ Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst, int exec_
dst.Location().device.ToString());
}
// Batched copy of every src/dst pair in src_dst_pairs.
//
// - An empty list is a no-op and returns OK.
// - The IDataTransfer used for the batch is the first registered one whose CanCopy accepts the
//   source/destination devices of the first pair. If none is registered a FAIL status is returned.
// - If every pair copies between the same two devices as the first pair, the whole list is handed to
//   that IDataTransfer's CopyTensors in a single call.
// - Otherwise the pairs are copied one at a time, returning on the first failure.
common::Status DataTransferManager::CopyTensors(const std::vector<IDataTransfer::SrcDstPair>& src_dst_pairs) const {
  if (src_dst_pairs.empty()) {
    return Status::OK();
  }

  const IDataTransfer::SrcDstPair& head = src_dst_pairs.front();
  const OrtDevice& head_src_device = head.src.get().Location().device;
  const OrtDevice& head_dst_device = head.dst.get().Location().device;

  // the remaining pairs, i.e. everything after the first entry
  const auto rest_begin = src_dst_pairs.cbegin() + 1;
  const auto rest_end = src_dst_pairs.cend();

  // true if any later pair uses a different source or destination device than the first pair
  const bool mixed_devices =
      std::any_of(rest_begin, rest_end,
                  [&head_src_device, &head_dst_device](const IDataTransfer::SrcDstPair& entry) {
                    return !(entry.src.get().Location().device == head_src_device) ||
                           !(entry.dst.get().Location().device == head_dst_device);
                  });

  // look up the IDataTransfer that handles the first pair's devices
  IDataTransfer* head_transfer = nullptr;
  for (auto& candidate : datatransfers_) {
    if (!candidate->CanCopy(head_src_device, head_dst_device)) {
      continue;
    }

    head_transfer = candidate.get();
    break;
  }

  if (head_transfer == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME,
                           FAIL,
                           "There's no data transfer registered for copying tensors from ",
                           head_src_device.ToString(),
                           " to ",
                           head_dst_device.ToString());
  }

  // uniform devices: hand the whole batch to the one IDataTransfer so it can do the copies together
  if (!mixed_devices) {
    return head_transfer->CopyTensors(src_dst_pairs);
  }

  // Mixed devices are not expected, so fall back to copying pair by pair. If this becomes a common
  // case the pairs could be grouped per IDataTransfer instance to batch as much as possible.
  // The first pair reuses the IDataTransfer found above; the rest go through CopyTensor, which does
  // its own lookup.
  ORT_RETURN_IF_ERROR(head_transfer->CopyTensor(head.src.get(), head.dst.get(), head.exec_queue_id));

  for (auto entry = rest_begin; entry != rest_end; ++entry) {
    ORT_RETURN_IF_ERROR(CopyTensor(entry->src, entry->dst, entry->exec_queue_id));
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -22,6 +22,7 @@ class DataTransferManager {
common::Status CopyTensor(const Tensor& src, Tensor& dst) const;
common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const;
// Batched copy. If all pairs copy between the same source and destination devices, the whole list is
// passed to the matching IDataTransfer in one call; otherwise the pairs are copied one at a time.
// Returns OK for an empty list, and a failure status if no registered IDataTransfer can copy the first pair.
common::Status CopyTensors(const std::vector<IDataTransfer::SrcDstPair>& src_dst_pairs) const;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);

View file

@ -5,7 +5,6 @@
#include <iomanip>
#include "core/graph/graph_viewer.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_frame.h"
@ -140,10 +139,16 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info)
return required_provider_type;
}
static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
const MLValueCopyInfo& copy_info,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue) {
// Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_pairs is provided,
// src/dst pairs that need a device copy are added to copy_pairs so copying can be batched by the DataTransferManager
// implementation for performance reasons.
static Status BatchOrCopyMLValue(
const DataTransferManager& data_transfer_mgr,
const MLValueCopyInfo& copy_info,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue,
std::vector<IDataTransfer::SrcDstPair>* copy_pairs = nullptr) {
// same device so direct copy
if (copy_info.source_device == copy_info.target_device) {
target_mlvalue = source_mlvalue;
return Status::OK();
@ -164,7 +169,11 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
if (copy_pairs != nullptr) {
copy_pairs->push_back({source_tensor, *p_output_tensor, 0});
} else {
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
}
return Status::OK();
}
@ -385,9 +394,16 @@ static common::Status CopyInputsAcrossDevices(const std::vector<OrtValue>& orig_
ORT_ENFORCE(copy_info.size() == num_feeds);
new_feeds.resize(num_feeds);
std::vector<IDataTransfer::SrcDstPair> batched_data_transfers;
batched_data_transfers.reserve(num_feeds);
for (size_t idx = 0; idx < num_feeds; ++idx) {
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx]));
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx],
&batched_data_transfers));
}
if (!batched_data_transfers.empty()) {
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensors(batched_data_transfers));
}
return Status::OK();
@ -408,7 +424,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, input_name, copy_info));
copy_info.source_device = orig_mlvalue.Get<Tensor>().Location().device;
return CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue);
return BatchOrCopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue);
}
static common::Status CopyOutputsAcrossDevices(const SessionState& session_state,
@ -419,9 +435,16 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
user_fetches.resize(num_outputs);
const auto& data_transfer_mgr = session_state.GetDataTransferMgr();
std::vector<IDataTransfer::SrcDstPair> batched_data_transfers;
batched_data_transfers.reserve(num_outputs);
for (size_t idx = 0; idx < num_outputs; ++idx) {
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx]));
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx],
&batched_data_transfers));
}
if (!batched_data_transfers.empty()) {
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensors(batched_data_transfers));
}
return Status::OK();