mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Update to match ORT signature
This commit is contained in:
parent
c925156fd2
commit
027b0cb3f3
2 changed files with 9 additions and 10 deletions
|
|
@ -540,7 +540,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
|
||||
Status ExecutionProviderImpl::CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count) const
|
||||
Status ExecutionProviderImpl::CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const
|
||||
{
|
||||
// Source and destination for batched GPU -> CPU copies
|
||||
std::vector<ID3D12Resource*> srcDatas;
|
||||
|
|
@ -550,24 +550,24 @@ namespace Dml
|
|||
assert(!m_closed);
|
||||
auto provider = const_cast<ExecutionProviderImpl*>(this);
|
||||
|
||||
for (uint32_t i = 0; i < count; ++i)
|
||||
for (uint32_t i = 0; i < src_dst_pairs.size(); ++i)
|
||||
{
|
||||
// This batching implementation only handles GPU -> CPU copies. Other copies do not require synchronization
|
||||
// and are batched across multiple calls to CopyTensor.
|
||||
if (!IsGpuTensor(*src[i]) || IsGpuTensor(*dst[i]))
|
||||
if (!IsGpuTensor(src_dst_pairs[i].src) || IsGpuTensor(src_dst_pairs[i].dst))
|
||||
{
|
||||
ORT_RETURN_IF_ERROR(CopyTensor(*src[i], *dst[i]));
|
||||
ORT_RETURN_IF_ERROR(CopyTensor(src_dst_pairs[i].src, src_dst_pairs[i].dst));
|
||||
continue;
|
||||
}
|
||||
|
||||
TensorWrapper srcWrapper = TensorWrapper(
|
||||
const_cast<onnxruntime::Tensor*>(src[i]),
|
||||
const_cast<onnxruntime::Tensor*>(&src_dst_pairs[i].src.get()),
|
||||
false,
|
||||
provider,
|
||||
true);
|
||||
|
||||
TensorWrapper dstWrapper = TensorWrapper(
|
||||
dst[i],
|
||||
&src_dst_pairs[i].dst.get(),
|
||||
true,
|
||||
provider,
|
||||
true);
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ namespace Dml
|
|||
uint32_t GetSuppportedDeviceDataTypeMask() const;
|
||||
|
||||
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
|
||||
onnxruntime::common::Status CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count) const;
|
||||
onnxruntime::common::Status CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const;
|
||||
onnxruntime::common::Status WaitForGpuCompletion();
|
||||
|
||||
// IWinmlExecutionProvider methods
|
||||
|
|
@ -200,10 +200,9 @@ namespace Dml
|
|||
return m_impl->CopyTensor(src, dst);
|
||||
}
|
||||
|
||||
onnxruntime::common::Status CopyTensors(const onnxruntime::Tensor** src, onnxruntime::Tensor** dst, uint32_t count, int exec_queue_id) const
|
||||
onnxruntime::common::Status CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const
|
||||
{
|
||||
assert(exec_queue_id == 0);
|
||||
return m_impl->CopyTensors(src, dst, count);
|
||||
return m_impl->CopyTensors(src_dst_pairs);
|
||||
}
|
||||
|
||||
bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
|
||||
|
|
|
|||
Loading…
Reference in a new issue