From 027b0cb3f3da2b250e6a9ef3de64bec6ccc486d6 Mon Sep 17 00:00:00 2001 From: Jeff Date: Fri, 17 Apr 2020 15:04:58 -0700 Subject: [PATCH] Update to match ORT signature --- .../DmlExecutionProvider/src/ExecutionProvider.cpp | 12 ++++++------ .../dml/DmlExecutionProvider/src/ExecutionProvider.h | 7 +++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index f43487f497..f01f9de534 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -540,7 +540,7 @@ namespace Dml } - Status ExecutionProviderImpl::CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count) const + Status ExecutionProviderImpl::CopyTensors(const std::vector& src_dst_pairs) const { // Source and destination for batched GPU -> CPU copies std::vector srcDatas; @@ -550,24 +550,24 @@ namespace Dml assert(!m_closed); auto provider = const_cast(this); - for (uint32_t i = 0; i < count; ++i) + for (uint32_t i = 0; i < src_dst_pairs.size(); ++i) { // This batching implementation only handles GPU -> CPU copies. Other copies do not require synchronization // and are batched across multiple calls to CopyTensor. - if (!IsGpuTensor(*src[i]) || IsGpuTensor(*dst[i])) + if (!IsGpuTensor(src_dst_pairs[i].src) || IsGpuTensor(src_dst_pairs[i].dst)) { - ORT_RETURN_IF_ERROR(CopyTensor(*src[i], *dst[i])); + ORT_RETURN_IF_ERROR(CopyTensor(src_dst_pairs[i].src, src_dst_pairs[i].dst)); continue; } TensorWrapper srcWrapper = TensorWrapper( - const_cast(src[i]), + const_cast(&src_dst_pairs[i].src.get()), false, provider, true); TensorWrapper dstWrapper = TensorWrapper( - dst[i], + &src_dst_pairs[i].dst.get(), true, provider, true); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 48798ca2c8..8e85171e54 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -92,7 +92,7 @@ namespace Dml uint32_t GetSuppportedDeviceDataTypeMask() const; onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const; - onnxruntime::common::Status CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count) const; + onnxruntime::common::Status CopyTensors(const std::vector& src_dst_pairs) const; onnxruntime::common::Status WaitForGpuCompletion(); // IWinmlExecutionProvider methods @@ -200,10 +200,9 @@ namespace Dml return m_impl->CopyTensor(src, dst); } - onnxruntime::common::Status CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count, int exec_queue_id) const + onnxruntime::common::Status CopyTensors(const std::vector& src_dst_pairs) const { - assert(exec_queue_id == 0); - return m_impl->CopyTensors(src, dst, count); + return m_impl->CopyTensors(src_dst_pairs); } bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final