From f50fa46fe09819734550a7e8725db999210e2b97 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 25 Sep 2023 12:21:20 -0700 Subject: [PATCH] [JSEP] allow DataTransfer to deal with zero sized input (#17661) ### Description allow DataTransfer to deal with zero sized input. This is a standalone fix for zero-sized tensor handling for JSEP DataTransfer. There are other components in JSEP not supporting zero-sized tensors need to be fixed. --- .../core/providers/js/data_transfer.cc | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index c62362d908..ebea041b80 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -20,23 +20,25 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { size_t bytes = src.SizeInBytes(); - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); + if (bytes > 0) { + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); - auto& src_device = src.Location().device; - auto& dst_device = dst.Location().device; + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { - // copy from GPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes); - } else { - // copy from CPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes); + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes); + } else { + // copy from CPU to GPU + EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + jsepDownload(src_data, dst_data, bytes); } - } else /* if (src_device.Type() == OrtDevice::GPU) */ { - // copy from GPU to CPU - jsepDownload(src_data, dst_data, bytes); } return Status::OK();