Update to match ORT signature

This commit is contained in:
Jeff 2020-04-17 15:04:58 -07:00
parent c925156fd2
commit 027b0cb3f3
2 changed files with 9 additions and 10 deletions

View file

@ -540,7 +540,7 @@ namespace Dml
}
Status ExecutionProviderImpl::CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count) const
Status ExecutionProviderImpl::CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const
{
// Source and destination for batched GPU -> CPU copies
std::vector<ID3D12Resource*> srcDatas;
@ -550,24 +550,24 @@ namespace Dml
assert(!m_closed);
auto provider = const_cast<ExecutionProviderImpl*>(this);
for (uint32_t i = 0; i < count; ++i)
for (uint32_t i = 0; i < src_dst_pairs.size(); ++i)
{
// This batching implementation only handles GPU -> CPU copies. Other copies do not require synchronization
// and are batched across multiple calls to CopyTensor.
if (!IsGpuTensor(*src[i]) || IsGpuTensor(*dst[i]))
if (!IsGpuTensor(src_dst_pairs[i].src) || IsGpuTensor(src_dst_pairs[i].dst))
{
ORT_RETURN_IF_ERROR(CopyTensor(*src[i], *dst[i]));
ORT_RETURN_IF_ERROR(CopyTensor(src_dst_pairs[i].src, src_dst_pairs[i].dst));
continue;
}
TensorWrapper srcWrapper = TensorWrapper(
const_cast<onnxruntime::Tensor*>(src[i]),
const_cast<onnxruntime::Tensor*>(&src_dst_pairs[i].src.get()),
false,
provider,
true);
TensorWrapper dstWrapper = TensorWrapper(
dst[i],
&src_dst_pairs[i].dst.get(),
true,
provider,
true);

View file

@ -92,7 +92,7 @@ namespace Dml
uint32_t GetSuppportedDeviceDataTypeMask() const;
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
onnxruntime::common::Status CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count) const;
onnxruntime::common::Status CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const;
onnxruntime::common::Status WaitForGpuCompletion();
// IWinmlExecutionProvider methods
@ -200,10 +200,9 @@ namespace Dml
return m_impl->CopyTensor(src, dst);
}
onnxruntime::common::Status CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count, int exec_queue_id) const
onnxruntime::common::Status CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const
{
assert(exec_queue_id == 0);
return m_impl->CopyTensors(src, dst, count);
return m_impl->CopyTensors(src_dst_pairs);
}
bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final