diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc index b4287003db..d94bb30fdf 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.cc @@ -3,6 +3,7 @@ #include "scatter_nd.h" #include "core/platform/threadpool.h" +#include "core/providers/cpu/tensor/utils.h" namespace onnxruntime { @@ -11,73 +12,87 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 11, 12, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), ScatterND); ONNX_CPU_OPERATOR_KERNEL( ScatterND, 13, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), ScatterND); -template -Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { - const auto* input_tensor = context->Input(0); - const auto* indice_tensor = context->Input(1); - const auto* update_tensor = context->Input(2); - ORT_ENFORCE(input_tensor != nullptr); - ORT_ENFORCE(indice_tensor != nullptr); - ORT_ENFORCE(update_tensor != nullptr); +Status ScatterNDBase::ValidateShapes(const TensorShape& input_shape, + const TensorShape& indice_shape, + const TensorShape& update_shape) { + auto input_rank = input_shape.NumDimensions(); + auto indice_rank = indice_shape.NumDimensions(); + auto update_rank = update_shape.NumDimensions(); - const auto& input_shape = input_tensor->Shape(); - const auto& indice_shape = indice_tensor->Shape(); - const auto& update_shape = update_tensor->Shape(); - if (indice_shape.NumDimensions() == 0 || input_shape.NumDimensions() == 0) { + if (input_rank == 0 || indice_rank == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input tensor and indices tensor must has rank larger than 0. ", "input shape: ", input_shape, ", indices shape: ", indice_shape); } - auto indice_rank = indice_shape.NumDimensions(); auto last_indice_dimension = indice_shape[indice_rank - 1]; - if (last_indice_dimension > static_cast(input_shape.NumDimensions())) { + if (last_indice_dimension > static_cast(input_rank)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last dimension of indices must not be larger than rank of input tensor"); } bool is_update_shape_invalid = [&]() { - auto update_rank = update_shape.NumDimensions(); - auto input_rank = input_shape.NumDimensions(); - if (update_rank < indice_rank - 1) { + // Validate rank of update tensor + // Per spec, the rank of the update tensor should be: + // (Rank of input tensor) + (Rank of indices tensor) -1 - last_indice_dimension + if (update_rank != (input_rank + indice_rank - 1 - static_cast(last_indice_dimension))) { return true; } - if (update_rank >= indice_rank - 1 && - indice_rank >= 1 && - (indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1))) { + + // Validate shape of the update tensor + // Part 1: The shape of the update tensor upto the indices rank - 1 (exclusive) + // should match the shape of the indices tensor upto indices rank - 1 (exclusive) + if (indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1)) { return true; } - if ((static_cast(input_rank) > last_indice_dimension) && - (update_rank >= indice_rank - 1) && - (input_shape.Slice(last_indice_dimension) != update_shape.Slice(indice_rank - 1))) { + + // Part 2: The shape of the update tensor after indices rank - 1 (inclusive) + // should match the shape of the input tensor after `last_indice_dimension` + if (input_shape.Slice(last_indice_dimension) != update_shape.Slice(indice_rank - 1)) { return true; } + return false; }(); + if (is_update_shape_invalid) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ", "updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape); } + return Status::OK(); +} + +Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { + const auto* input_tensor = context->Input(0); + const auto* indice_tensor = context->Input(1); + const auto* update_tensor = context->Input(2); + + const auto& input_shape = input_tensor->Shape(); + const auto& indice_shape = indice_tensor->Shape(); + const auto& update_shape = update_tensor->Shape(); + + ORT_RETURN_IF_ERROR(ValidateShapes(input_shape, indice_shape, update_shape)); + auto output_tensor = context->Output(0, input_shape); const auto* src_base = input_tensor->DataRaw(); auto* dst_base = output_tensor->MutableDataRaw(); const bool is_string_type = input_tensor->IsDataTypeString(); + auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1]; + // Re-use input for output. If input/output Tensor* are the same, do not copy. if (src_base != dst_base) { if (is_string_type) { @@ -92,15 +107,16 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co std::vector element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension + TensorPitches input_strides(input_shape); for (int64_t i = 0; i < last_indice_dimension; ++i) { - element_counts[i] = input_shape.SizeFromDimension(i + 1); + element_counts[i] = input_strides[i]; } int64_t err_indice = 0; p.element_bytes = input_tensor->DataType()->Size(); p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); p.bytes_to_copy = p.element_bytes * p.element_to_copy; - auto indice_offset = static_cast(indice_tensor->DataRaw()); + const int64_t* indice_offset = indice_tensor->template Data(); auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy p.element_offsets.assign(offset_count, 0LL); @@ -124,12 +140,10 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); } -template Status ScatterNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; - Status ScatterND::Compute(OpKernelContext* context) const { Prepare p; concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); + ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); return nullptr == p.input_str_base ? ScatterNumber(p, tp) : ScatterString(p, tp); } diff --git a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h index 9d9a52ea41..308c542722 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter_nd.h +++ b/onnxruntime/core/providers/cpu/tensor/scatter_nd.h @@ -32,7 +32,11 @@ class ScatterNDBase { element_offsets(0) {} }; // struct Prepare - template + // Shared between the CPU and CUDA implementation + static Status ValidateShapes(const TensorShape& input_shape, + const TensorShape& indice_shape, + const TensorShape& update_shape); + Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; }; // class ScatterNDBase diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 835555d07d..c8cdf274f4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -759,6 +759,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t_float_int32_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t_MLFloat16_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t_MLFloat16_int32_t, OneHot); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, ScatterND); // OpSet 12 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip); @@ -988,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Identity); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1462,6 +1464,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 12 BuildKernelCreateInfo, @@ -1691,6 +1694,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc new file mode 100644 index 0000000000..bec33de6dd --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/tensor/scatter_nd.h" +#include "core/providers/cuda/tensor/scatter_nd_impl.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "core/providers/cpu/tensor/utils.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND, + kOnnxDomain, + 11, 12, + kCudaExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .MayInplace(0, 0), + ScatterND); + +ONNX_OPERATOR_KERNEL_EX(ScatterND, + kOnnxDomain, + 13, + kCudaExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .MayInplace(0, 0), + ScatterND); + +Status ScatterND::ComputeInternal(OpKernelContext* context) const { + const auto* input_tensor = context->Input(0); + const auto* indices_tensor = context->Input(1); + const auto* updates_tensor = context->Input(2); + + const auto& input_shape = input_tensor->Shape(); + const auto& indices_shape = indices_tensor->Shape(); + const auto& updates_shape = updates_tensor->Shape(); + + // Validate input shapes + ValidateShapes(input_shape, indices_shape, updates_shape); + + auto* output_tensor = context->Output(0, input_shape); + + const void* input_data = input_tensor->DataRaw(); + void* output_data = output_tensor->MutableDataRaw(); + + size_t element_size = input_tensor->DataType()->Size(); + + if (input_data != output_data) { + // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ? + cudaMemcpyAsync(output_data, input_data, element_size * input_shape.Size(), cudaMemcpyDeviceToDevice); + } + + // Bail out early + if (indices_shape.Size() == 0) { + return Status::OK(); + } + + auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1]; + + // We need element counts for each dimension and the input dim value for each dimension + // for the range [0, last_index_dimension). + // To avoid multiple GPU data transfers, we combine this into one array and send it through + TensorPitches input_strides(input_shape); + std::vector element_counts_and_input_dims(last_index_dimension * 2, 0LL); + for (int64_t i = 0; i < last_index_dimension; ++i) { + element_counts_and_input_dims[i] = input_strides[i]; + element_counts_and_input_dims[i + last_index_dimension] = input_shape[i]; + } + CudaAsyncBuffer element_counts_and_input_dims_gpu(this, element_counts_and_input_dims); + element_counts_and_input_dims_gpu.CopyToGpu(); + + ORT_RETURN_IF_ERROR(ScatterNDImpl( + output_data, + element_size, + indices_shape.Size() / static_cast(last_index_dimension), + indices_tensor->Data(), // only int64_t is supported for indices as per the onnx spec + last_index_dimension, + element_counts_and_input_dims_gpu.GpuPtr(), + updates_tensor->DataRaw(), + input_shape.SizeFromDimension(last_index_dimension))); + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h new file mode 100644 index 0000000000..89bd6104a5 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cpu/tensor/scatter_nd.h" + +namespace onnxruntime { +namespace cuda { + +class ScatterND final : public CudaKernel, protected ScatterNDBase { + public: + explicit ScatterND(const OpKernelInfo& info) : CudaKernel(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu new file mode 100644 index 0000000000..213e8d9ed2 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/tensor/scatter_nd_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/atomic/common.cuh" + +namespace onnxruntime { +namespace cuda { + +template +__global__ void _ScatterNDKernel( + T* output_data, + const size_t num_indices, + const int64_t* indices_data, + const int64_t last_index_dimension, + const int64_t* element_counts_and_input_dims, + const T* updates_data, + const size_t num_updates_elements) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices); + + // Compute the base offset into the output data + int64_t data_offset = 0; + + size_t indices_start = last_index_dimension * id; + size_t indices_end = indices_start + last_index_dimension; + for (size_t i = indices_start; i < indices_end; ++i) { + int64_t index = indices_data[i]; + + int64_t element_count_dim = element_counts_and_input_dims[i - indices_start]; + int64_t dim_value = element_counts_and_input_dims[i - indices_start + last_index_dimension]; + + // Clamp the index if out of range + // This would have been an error in the CPU kernel, but throwing in the CUDA EP + // is hard. This is the approach taken by other frameworks for out of bound indices + // in their corresponding GPU backends as well. + if (index < 0) + index = 0; + + else if (index >= dim_value) + index = dim_value - 1; + + data_offset += (index * element_count_dim); + } + + const T* updates_data_base = updates_data + num_updates_elements * id; + T* output_data_base = output_data + data_offset; + + for (size_t i = 0; i < num_updates_elements; ++i) { + output_data_base[i] = updates_data_base[i]; + } +} + +Status ScatterNDImpl( + void* output_data, + const size_t element_size, + const size_t num_indices, + const int64_t* indices_data, + const int64_t last_index_dimension, + const int64_t* element_counts_and_input_dims, + const void* updates_data, + const size_t num_updates_elements) { + if (num_indices == 0) + return Status::OK(); + + // Parallelize on number of indices + int blocksPerGrid = static_cast(ceil(static_cast(num_indices) / GridDim::maxThreadsPerBlock)); + + switch (element_size) { + case sizeof(int8_t): + _ScatterNDKernel<<>>( + reinterpret_cast(output_data), + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + reinterpret_cast(updates_data), + num_updates_elements); + break; + + case sizeof(int16_t): + _ScatterNDKernel<<>>( + reinterpret_cast(output_data), + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + reinterpret_cast(updates_data), + num_updates_elements); + break; + + case sizeof(int32_t): + _ScatterNDKernel<<>>( + reinterpret_cast(output_data), + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + reinterpret_cast(updates_data), + num_updates_elements); + break; + + case sizeof(int64_t): + _ScatterNDKernel<<>>( + reinterpret_cast(output_data), + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + reinterpret_cast(updates_data), + num_updates_elements); + break; + + default: + // Shouldn't hit this + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for ScatterND operator"); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h new file mode 100644 index 0000000000..de9bad886d --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +Status ScatterNDImpl( + void* output_data, + const size_t element_size, + const size_t num_indices, + const int64_t* indices_data, + const int64_t last_index_dimension, + const int64_t* element_counts_and_input_dims, + const void* updates_data, + const size_t num_updates_elements); + +} // namespace cuda +} // namespace onnxruntime