diff --git a/onnxruntime/core/framework/data_transfer.cc b/onnxruntime/core/framework/data_transfer.cc index 80099c2e4e..77dbba8b5f 100644 --- a/onnxruntime/core/framework/data_transfer.cc +++ b/onnxruntime/core/framework/data_transfer.cc @@ -9,11 +9,11 @@ common::Status IDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { return CopyTensor(src, dst, 0); } -common::Status IDataTransfer::CopyTensors(const Tensor* src, Tensor* dst, int size) const { - ORT_ENFORCE(nullptr != src && nullptr != dst); - for (int i = 0; i < size; ++i) { - ORT_RETURN_IF_ERROR(CopyTensor(src[i], dst[i], 0)); +common::Status IDataTransfer::CopyTensors(const std::vector& src_dst_pairs) const { + for (const auto& pair : src_dst_pairs) { + ORT_RETURN_IF_ERROR(CopyTensor(pair.src, pair.dst, pair.exec_queue_id)); } + return Status::OK(); } diff --git a/onnxruntime/core/framework/data_transfer.h b/onnxruntime/core/framework/data_transfer.h index 7278f067db..e2093f0abf 100644 --- a/onnxruntime/core/framework/data_transfer.h +++ b/onnxruntime/core/framework/data_transfer.h @@ -17,7 +17,15 @@ class IDataTransfer { virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const; virtual common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const = 0; - virtual common::Status CopyTensors(const Tensor* src, Tensor* dst, int size) const; + + struct SrcDstPair { + std::reference_wrapper src; + std::reference_wrapper dst; + int exec_queue_id; + }; + + // batched copy. default implementation copies each entry sequentially, and returns on first failure. + virtual common::Status CopyTensors(const std::vector& src_dst_pairs) const; }; class CPUDataTransfer : public IDataTransfer { diff --git a/onnxruntime/core/framework/data_transfer_manager.cc b/onnxruntime/core/framework/data_transfer_manager.cc index 3f07ac7559..72ff3a19b1 100644 --- a/onnxruntime/core/framework/data_transfer_manager.cc +++ b/onnxruntime/core/framework/data_transfer_manager.cc @@ -25,7 +25,6 @@ const IDataTransfer* DataTransferManager::GetDataTransfer(const OrtDevice& src_d return nullptr; } - Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst) const { return CopyTensor(src, dst, 0); } @@ -51,4 +50,55 @@ Status DataTransferManager::CopyTensor(const Tensor& src, Tensor& dst, int exec_ dst.Location().device.ToString()); } +common::Status DataTransferManager::CopyTensors(const std::vector& src_dst_pairs) const { + if (src_dst_pairs.empty()) + return Status::OK(); + + const auto& first_pair = src_dst_pairs.front(); + const OrtDevice& src_device = first_pair.src.get().Location().device; + const OrtDevice& dst_device = first_pair.dst.get().Location().device; + + bool all_same = std::all_of(src_dst_pairs.cbegin() + 1, src_dst_pairs.cend(), + [&src_device, &dst_device](const IDataTransfer::SrcDstPair& pair) { + return pair.src.get().Location().device == src_device && + pair.dst.get().Location().device == dst_device; + }); + + IDataTransfer* first_dt = nullptr; + + for (auto& data_transfer : datatransfers_) { + if (data_transfer->CanCopy(src_device, dst_device)) { + first_dt = data_transfer.get(); + break; + } + } + + if (first_dt == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, + FAIL, + "There's no data transfer registered for copying tensors from ", + src_device.ToString(), + " to ", + dst_device.ToString()); + } + + // all copies are between the same devices so we can do them all at once + if (all_same) { + return first_dt->CopyTensors(src_dst_pairs); + } + + // there are a mix of devices requiring copies. we don't expect this to happen, so just iterate the pairs + // copying one at a time. if this becomes expected we could create a list for each IDataTransfer instance so we + // batch as much as possible. + + // copy the first one as we already did the IDataTransfer lookup + ORT_RETURN_IF_ERROR(first_dt->CopyTensor(first_pair.src.get(), first_pair.dst.get(), first_pair.exec_queue_id)); + + for (auto cur_pair = src_dst_pairs.cbegin() + 1, end_pair = src_dst_pairs.cend(); cur_pair != end_pair; ++cur_pair) { + ORT_RETURN_IF_ERROR(CopyTensor(cur_pair->src, cur_pair->dst, cur_pair->exec_queue_id)); + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/data_transfer_manager.h b/onnxruntime/core/framework/data_transfer_manager.h index 35c85a0ba1..c7667d2049 100644 --- a/onnxruntime/core/framework/data_transfer_manager.h +++ b/onnxruntime/core/framework/data_transfer_manager.h @@ -22,6 +22,7 @@ class DataTransferManager { common::Status CopyTensor(const Tensor& src, Tensor& dst) const; common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const; + common::Status CopyTensors(const std::vector& src_dst_pairs) const; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager); diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 5ffc381c65..9bbf4dd9f7 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -5,7 +5,6 @@ #include - #include "core/graph/graph_viewer.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/execution_frame.h" @@ -140,10 +139,16 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) return required_provider_type; } -static Status CopyMLValue(const DataTransferManager& data_transfer_mgr, - const MLValueCopyInfo& copy_info, - const OrtValue& source_mlvalue, - OrtValue& target_mlvalue) { +// Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_pairs is provided, +// src/dst pairs that need a device copy are added to copy_pairs so copying can be batches by the DataTransferManager +// implementation for performance reasons. +static Status BatchOrCopyMLValue( + const DataTransferManager& data_transfer_mgr, + const MLValueCopyInfo& copy_info, + const OrtValue& source_mlvalue, + OrtValue& target_mlvalue, + std::vector* copy_pairs = nullptr) { + // same device so direct copy if (copy_info.source_device == copy_info.target_device) { target_mlvalue = source_mlvalue; return Status::OK(); @@ -164,7 +169,11 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr, Tensor* p_output_tensor = target_mlvalue.GetMutable(); - ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor)); + if (copy_pairs != nullptr) { + copy_pairs->push_back({source_tensor, *p_output_tensor, 0}); + } else { + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor)); + } return Status::OK(); } @@ -385,9 +394,16 @@ static common::Status CopyInputsAcrossDevices(const std::vector& orig_ ORT_ENFORCE(copy_info.size() == num_feeds); new_feeds.resize(num_feeds); + std::vector batched_data_transfers; + batched_data_transfers.reserve(num_feeds); for (size_t idx = 0; idx < num_feeds; ++idx) { - ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx])); + ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(data_transfer_mgr, copy_info[idx], orig_feeds[idx], new_feeds[idx], + &batched_data_transfers)); + } + + if (!batched_data_transfers.empty()) { + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensors(batched_data_transfers)); } return Status::OK(); @@ -408,7 +424,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, input_name, copy_info)); copy_info.source_device = orig_mlvalue.Get().Location().device; - return CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue); + return BatchOrCopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue); } static common::Status CopyOutputsAcrossDevices(const SessionState& session_state, @@ -419,9 +435,16 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state user_fetches.resize(num_outputs); const auto& data_transfer_mgr = session_state.GetDataTransferMgr(); + std::vector batched_data_transfers; + batched_data_transfers.reserve(num_outputs); for (size_t idx = 0; idx < num_outputs; ++idx) { - ORT_RETURN_IF_ERROR(CopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx])); + ORT_RETURN_IF_ERROR(BatchOrCopyMLValue(data_transfer_mgr, copy_info[idx], fetches[idx], user_fetches[idx], + &batched_data_transfers)); + } + + if (!batched_data_transfers.empty()) { + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensors(batched_data_transfers)); } return Status::OK();