mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Add ability to batch device copy for graph inputs and outputs. (#3580)
* Add ability to batch device copy for graph inputs and outputs.
This commit is contained in:
parent
ea62b3435a
commit
7d5348f87e
5 changed files with 97 additions and 15 deletions
|
|
@ -9,11 +9,11 @@ common::Status IDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
|
|||
return CopyTensor(src, dst, 0);
|
||||
}
|
||||
|
||||
common::Status IDataTransfer::CopyTensors(const Tensor* src, Tensor* dst, int size) const {
|
||||
ORT_ENFORCE(nullptr != src && nullptr != dst);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ORT_RETURN_IF_ERROR(CopyTensor(src[i], dst[i], 0));
|
||||
// Default batched copy: walks src_dst_pairs in order and copies each entry with the
// three-argument CopyTensor, passing along the entry's exec_queue_id.
// Stops at, and returns, the first failing status; returns OK once every entry has been copied.
common::Status IDataTransfer::CopyTensors(const std::vector<IDataTransfer::SrcDstPair>& src_dst_pairs) const {
  for (auto entry = src_dst_pairs.cbegin(), last = src_dst_pairs.cend(); entry != last; ++entry) {
    ORT_RETURN_IF_ERROR(CopyTensor(entry->src, entry->dst, entry->exec_queue_id));
  }

  return Status::OK();
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,15 @@ class IDataTransfer {
|
|||
|
||||
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const;
|
||||
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const = 0;
|
||||
virtual common::Status CopyTensors(const Tensor* src, Tensor* dst, int size) const;
|
||||
|
||||
// One entry of a batched copy request: a source tensor, the tensor to copy it into, and the
// exec_queue_id that is handed to CopyTensor for this entry.
// Callers build it with brace initialization in the order {src, dst, exec_queue_id}, so the
// member order below must not change.
struct SrcDstPair {
|
||||
// tensor to read from; held by reference, not owned
std::reference_wrapper<const Tensor> src;
|
||||
// tensor to write to; held by reference, not owned
std::reference_wrapper<Tensor> dst;
|
||||
// forwarded as the exec_queue_id argument of CopyTensor(src, dst, exec_queue_id)
int exec_queue_id;
|
||||
};
|
||||
|
||||
// batched copy. default implementation copies each entry sequentially, and returns on first failure.
|
||||
virtual common::Status CopyTensors(const std::vector<SrcDstPair>& src_dst_pairs) const;
|
||||
};
|
||||
|
||||
class CPUDataTransfer : public IDataTransfer {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ const IDataTransfer* DataTransferManager::GetDataTransfer(const OrtDevice& src_d
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
// Convenience overload: copies src into dst by forwarding to the three-argument CopyTensor
// with an exec_queue_id of 0, and returns that call's status unchanged.
Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst) const {
|
||||
return CopyTensor(src, dst, 0);
|
||||
}
|
||||
|
|
@ -51,4 +50,55 @@ Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst, int exec_
|
|||
dst.Location().device.ToString());
|
||||
}
|
||||
|
||||
// Copies every src/dst pair in src_dst_pairs. An empty list succeeds immediately.
//
// The IDataTransfer used for the batch is looked up from the devices of the first pair. If all
// remaining pairs copy between those same two devices, the whole list is handed to that
// IDataTransfer in one CopyTensors call so the implementation can batch the work. Otherwise the
// pairs are copied one at a time, returning on the first failure.
//
// Returns a FAIL status if no registered IDataTransfer can copy between the first pair's devices.
common::Status DataTransferManager::CopyTensors(const std::vector<IDataTransfer::SrcDstPair>& src_dst_pairs) const {
  if (src_dst_pairs.empty()) {
    return Status::OK();
  }

  const IDataTransfer::SrcDstPair& head = src_dst_pairs.front();
  const OrtDevice& src_device = head.src.get().Location().device;
  const OrtDevice& dst_device = head.dst.get().Location().device;

  const auto rest_begin = src_dst_pairs.cbegin() + 1;
  const auto rest_end = src_dst_pairs.cend();

  // true when at least one later pair copies between devices other than those of the first pair
  const bool mixed_devices =
      std::any_of(rest_begin, rest_end,
                  [&src_device, &dst_device](const IDataTransfer::SrcDstPair& entry) {
                    return !(entry.src.get().Location().device == src_device) ||
                           !(entry.dst.get().Location().device == dst_device);
                  });

  // first registered IDataTransfer that can handle the first pair's devices
  const IDataTransfer* head_transfer = nullptr;
  for (const auto& candidate : datatransfers_) {
    if (candidate->CanCopy(src_device, dst_device)) {
      head_transfer = candidate.get();
      break;
    }
  }

  if (head_transfer == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME,
                           FAIL,
                           "There's no data transfer registered for copying tensors from ",
                           src_device.ToString(),
                           " to ",
                           dst_device.ToString());
  }

  // every copy is between the same two devices, so do them all in one batched call
  if (!mixed_devices) {
    return head_transfer->CopyTensors(src_dst_pairs);
  }

  // There is a mix of devices. This is not expected to happen, so simply copy the pairs one at a
  // time. If it becomes common, the pairs could be grouped per IDataTransfer instance so that as
  // much as possible is still batched.

  // the IDataTransfer for the first pair has already been looked up, so use it directly
  ORT_RETURN_IF_ERROR(head_transfer->CopyTensor(head.src.get(), head.dst.get(), head.exec_queue_id));

  // the remaining pairs go through the single-tensor CopyTensor, which does its own lookup
  for (auto entry = rest_begin; entry != rest_end; ++entry) {
    ORT_RETURN_IF_ERROR(CopyTensor(entry->src, entry->dst, entry->exec_queue_id));
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class DataTransferManager {
|
|||
|
||||
common::Status CopyTensor(const Tensor& src, Tensor& dst) const;
|
||||
common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const;
|
||||
common::Status CopyTensors(const std::vector<IDataTransfer::SrcDstPair>& src_dst_pairs) const;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
#include <iomanip>
|
||||
|
||||
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/framework/data_transfer_manager.h"
|
||||
#include "core/framework/execution_frame.h"
|
||||
|
|
@ -140,10 +139,16 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info)
|
|||
return required_provider_type;
|
||||
}
|
||||
|
||||
static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
|
||||
const MLValueCopyInfo& copy_info,
|
||||
const OrtValue& source_mlvalue,
|
||||
OrtValue& target_mlvalue) {
|
||||
// Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_pairs is provided,
|
||||
// src/dst pairs that need a device copy are added to copy_pairs so copying can be batched by the DataTransferManager
|
||||
// implementation for performance reasons.
|
||||
static Status BatchOrCopyMLValue(
|
||||
const DataTransferManager& data_transfer_mgr,
|
||||
const MLValueCopyInfo& copy_info,
|
||||
const OrtValue& source_mlvalue,
|
||||
OrtValue& target_mlvalue,
|
||||
std::vector<IDataTransfer::SrcDstPair>* copy_pairs = nullptr) {
|
||||
// same device so direct copy
|
||||
if (copy_info.source_device == copy_info.target_device) {
|
||||
target_mlvalue = source_mlvalue;
|
||||
return Status::OK();
|
||||
|
|
@ -164,7 +169,11 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
|
|||
|
||||
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
|
||||
|
||||
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
|
||||
if (copy_pairs != nullptr) {
|
||||
copy_pairs->push_back({source_tensor, *p_output_tensor, 0});
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -385,9 +394,16 @@ static common::Status CopyInputsAcrossDevices(const std::vector<OrtValue>& orig_
|
|||
ORT_ENFORCE(copy_info.size() == num_feeds);
|
||||
|
||||
new_feeds.resize(num_feeds);
|
||||
std::vector<IDataTransfer::SrcDstPair> batched_data_transfers;
|
||||
batched_data_transfers.reserve(num_feeds);
|
||||
|
||||
for (size_t idx = 0; idx < num_feeds; ++idx) {
|
||||
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx]));
|
||||
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx],
|
||||
&batched_data_transfers));
|
||||
}
|
||||
|
||||
if (!batched_data_transfers.empty()) {
|
||||
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensors(batched_data_transfers));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -408,7 +424,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
|
|||
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, input_name, copy_info));
|
||||
copy_info.source_device = orig_mlvalue.Get<Tensor>().Location().device;
|
||||
|
||||
return CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue);
|
||||
return BatchOrCopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue);
|
||||
}
|
||||
|
||||
static common::Status CopyOutputsAcrossDevices(const SessionState& session_state,
|
||||
|
|
@ -419,9 +435,16 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
|
|||
user_fetches.resize(num_outputs);
|
||||
|
||||
const auto& data_transfer_mgr = session_state.GetDataTransferMgr();
|
||||
std::vector<IDataTransfer::SrcDstPair> batched_data_transfers;
|
||||
batched_data_transfers.reserve(num_outputs);
|
||||
|
||||
for (size_t idx = 0; idx < num_outputs; ++idx) {
|
||||
ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx]));
|
||||
ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx],
|
||||
&batched_data_transfers));
|
||||
}
|
||||
|
||||
if (!batched_data_transfers.empty()) {
|
||||
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensors(batched_data_transfers));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
Loading…
Reference in a new issue