From 0f6b3ea575a7fea39e9f9cec7b964f7777ea7945 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Tue, 14 May 2019 13:35:15 -0700 Subject: [PATCH] Improve the performance of Scatter (#991) * Improve the performance of scatter --- .../core/providers/cpu/tensor/scatter.cc | 64 +++++++++++++------ 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index be6e19b80e..3d3c78c1ec 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -30,7 +30,7 @@ ONNX_CPU_OPERATOR_KERNEL( .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Scatter); -template +template Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input, const int64_t axis, Tensor* data_output) { const TensorShape& input_data_shape = data_input->Shape(); @@ -45,12 +45,11 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co } const auto input_elements = input_data_shape.Size(); - const auto element_bytes = data_input->DataType()->Size(); const auto total_input_bytes = data_input->Size(); - const uint8_t* src_base = reinterpret_cast(data_input->DataRaw()); - uint8_t* dst_base = reinterpret_cast(data_output->MutableDataRaw()); - const bool is_string_type = data_input->DataType() == DataTypeImpl::GetType(); + const Tdata* src_base = static_cast(data_input->DataRaw()); + Tdata* dst_base = static_cast(data_output->MutableDataRaw()); + bool is_string_type = data_input->DataType() == DataTypeImpl::GetType(); // We allow runtime to re-use input for output. If input/output Tensor* are the same // we do not copy @@ -110,7 +109,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co } } - const uint8_t* update_data = reinterpret_cast(updates_input->DataRaw()); + const Tdata* update_data = static_cast(updates_input->DataRaw()); // For every update we compute the destination offset and copy it there for (int64_t index = 0; index < num_indices;) { const Tin axis_idx = indices_data[index]; @@ -127,16 +126,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co } } - const size_t dst_offset_bytes = dst_offset * element_bytes; - assert(dst_offset_bytes < total_input_bytes); - if (is_string_type) { - reinterpret_cast(dst_base)[dst_offset] = - reinterpret_cast(update_data)[index]; - } else { - // Copy an element - auto src_offset_bytes = index * element_bytes; - memcpy(dst_base + dst_offset_bytes, update_data + src_offset_bytes, element_bytes); - } + dst_base[dst_offset] = update_data[index]; if (++index == num_indices) { break; @@ -158,6 +148,38 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co return Status::OK(); } +#define DispatchOnIndexTypeAndTensorType(index_type, tensor_type, retval, function, ...) \ + if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type) + Status Scatter::Compute(OpKernelContext* context) const { const auto* data_input = context->Input(0); const auto& input_data_shape = data_input->Shape(); @@ -203,12 +225,16 @@ Status Scatter::Compute(OpKernelContext* context) const { auto* data_output = context->Output(0, input_data_shape); MLDataType Tind_type = indices_input->DataType(); + MLDataType Tdata_type = data_input->DataType(); + Status status; if (Tind_type == DataTypeImpl::GetType()) { - return CopyScatterData(data_input, indices_input, updates_input, axis, data_output); + DispatchOnIndexTypeAndTensorType(int32_t, Tdata_type, status, CopyScatterData, data_input, indices_input, updates_input, axis, data_output); } else if (Tind_type == DataTypeImpl::GetType()) { - return CopyScatterData(data_input, indices_input, updates_input, axis, data_output); + DispatchOnIndexTypeAndTensorType(int64_t, Tdata_type, status, CopyScatterData, data_input, indices_input, updates_input, axis, data_output); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t"); } - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expecting indices to be either int32_t or int64_t"); + return status; } } // namespace onnxruntime