diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc index ae0e77e835..e92f456c86 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -26,7 +26,7 @@ ONNX_OPERATOR_KERNEL_EX(GatherND, kMSDomain, 1, kCpuExecutionProvider, #endif ONNX_CPU_OPERATOR_KERNEL( - GatherND, + GatherND, 11, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) @@ -46,7 +46,7 @@ template Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { const auto* input_tensor = context->Input(0); const auto* indices_tensor = context->Input(1); - ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherND op: Input count mismatch"); + ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherNDBase PrepareForCompute: Input count mismatch"); const auto& input_shape = input_tensor->Shape(); const auto& indices_shape = indices_tensor->Shape(); @@ -54,7 +54,14 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices tensor must has rank larger than 0"); } - int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1] + batch_dims_; + const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1]; + const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1); + const auto slice_size = input_shape.SizeFromDimension(batch_dims_ + num_slice_dims); + const auto num_batches = input_shape.SizeToDimension(batch_dims_); + const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims_); + const auto num_slices_per_batch = num_slices / num_batches; + + int64_t last_indices_dimension = batch_dims_ + num_slice_dims; if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last dimension of indices must not be larger than rank of input tensor"); @@ -63,31 +70,21 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con std::vector shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1); shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end()); auto* output_tensor = context->Output(0, TensorShape(std::move(shape))); - std::vector element_counts(last_indices_dimension + batch_dims_, - 0LL); // Number of elements for each input dimension + std::vector sizes_from_slice_dims(num_slice_dims); #ifdef _OPENMP #pragma omp parallel for #endif - for (int64_t i = 0; i < last_indices_dimension; ++i) { - element_counts[i] = input_shape.SizeFromDimension(i + 1); - } - - auto last_dim_size = indices_shape.SizeFromDimension(indices_shape.NumDimensions() - 1); -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = batch_dims_ - 1; i >= 0; --i) { - element_counts[last_indices_dimension + i] = indices_shape.SizeFromDimension(i + 1) / last_dim_size; + for (int64_t i = 0; i < num_slice_dims; ++i) { + sizes_from_slice_dims[i] = input_shape.SizeFromDimension(batch_dims_ + i + 1); } int64_t err_index = 0; p.element_bytes = input_tensor->DataType()->Size(); - p.element_to_copy = input_shape.SizeFromDimension(last_indices_dimension); - p.bytes_to_copy = p.element_bytes * p.element_to_copy; - const auto* indice_offset = indices_tensor->Data(); - const int64_t offset_count = indices_shape.Size() / (last_indices_dimension - batch_dims_); // Times to copy - p.element_offsets.assign(offset_count, 0LL); + p.element_count_per_slice = slice_size; + p.bytes_per_slice = p.element_bytes * p.element_count_per_slice; + const auto* indices_data = indices_tensor->Data(); + p.slice_offsets.assign(num_slices, 0LL); if (input_tensor->IsDataTypeString()) { p.input_str_base = static_cast(input_tensor->DataRaw()); @@ -97,29 +94,30 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con p.output_base = static_cast(output_tensor->MutableDataRaw()); } - //Compute the element_offset + // Compute the element_offset #ifdef _OPENMP #pragma omp parallel for #endif - for (int64_t i = 0; i < offset_count; ++i) { - int64_t reminder = i; - for (int64_t j = 0; j < batch_dims_; ++j) { - int64_t idx = reminder / element_counts[last_indices_dimension + j]; - p.element_offsets[i] += idx * element_counts[j]; - reminder -= (idx * element_counts[last_indices_dimension + j]); - } - for (int64_t j = batch_dims_; j < last_indices_dimension; ++j) { - auto index = *(indice_offset + i * (last_indices_dimension - batch_dims_) + (j - batch_dims_)); - auto upper_limit = input_shape[j]; - auto lower_limit = -upper_limit; + for (int64_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) { + const size_t batch_idx = slice_idx / num_slices_per_batch; + const size_t input_base_offset = batch_idx * input_batch_stride; + + const auto* const slice_indices = indices_data + slice_idx * num_slice_dims; + size_t relative_slice_offset = 0; + for (int64_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) { + int64_t index = static_cast(slice_indices[dim_idx]); + const auto upper_limit = input_shape[batch_dims_ + dim_idx]; + const auto lower_limit = -upper_limit; if (index < lower_limit || index >= upper_limit) { err_index = index; + break; } - if (index < 0) { - index += static_cast(upper_limit); - } - p.element_offsets[i] += index * element_counts[j]; + if (index < 0) index += upper_limit; + + relative_slice_offset += index * sizes_from_slice_dims[dim_idx]; } + + p.slice_offsets[slice_idx] = input_base_offset + relative_slice_offset; } return err_index == 0 ? Status::OK() @@ -142,9 +140,9 @@ Status GatherND::GatherNumber(const Prepare& p) const { #ifdef _OPENMP #pragma omp parallel for #endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { - memcpy(p.output_base + i * p.bytes_to_copy, p.input_base + p.element_offsets[i] * p.element_bytes, - p.bytes_to_copy); + for (int64_t slice_idx = 0; slice_idx < static_cast(p.slice_offsets.size()); ++slice_idx) { + memcpy(p.output_base + slice_idx * p.bytes_per_slice, p.input_base + p.slice_offsets[slice_idx] * p.element_bytes, + p.bytes_per_slice); } return Status::OK(); @@ -154,9 +152,10 @@ Status GatherND::GatherString(const Prepare& p) const { #ifdef _OPENMP #pragma omp parallel for #endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { - for (int64_t j = 0; j < static_cast(p.element_to_copy); ++j) { - p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j]; + for (int64_t slice_idx = 0; slice_idx < static_cast(p.slice_offsets.size()); ++slice_idx) { + const int64_t slice_base_offset = slice_idx * p.element_count_per_slice; + for (int64_t j = 0; j < static_cast(p.element_count_per_slice); ++j) { + p.output_str_base[slice_base_offset + j] = p.input_str_base[p.slice_offsets[slice_idx] + j]; } } diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.h b/onnxruntime/core/providers/cpu/tensor/gather_nd.h index 135706a245..d3765751bc 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.h @@ -16,19 +16,19 @@ class GatherNDBase { const std::string* input_str_base; uint8_t* output_base; std::string* output_str_base; - uint64_t bytes_to_copy; + uint64_t bytes_per_slice; uint64_t element_bytes; - uint64_t element_to_copy; - std::vector element_offsets; + uint64_t element_count_per_slice; + std::vector slice_offsets; Prepare() : input_base(nullptr), input_str_base(nullptr), output_base(nullptr), output_str_base(nullptr), - bytes_to_copy(0), + bytes_per_slice(0), element_bytes(0), - element_to_copy(0), - element_offsets(0) {} + element_count_per_slice(0), + slice_offsets(0) {} }; // struct Prepare template diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc index 53e2fcd153..6829405bfd 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc @@ -38,40 +38,23 @@ Status CheckBatchDimensionsMatch( return Status::OK(); } -#define TYPED_FUNCTION_CALL_FWD(T) \ - if (T_type == DataTypeImpl::GetType()) { \ - GatherNDImpl::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \ - } - -#define TYPED_FUNCTION_CALL_BWD(T) \ - if (T_type == DataTypeImpl::GetType()) { \ - GatherNDGradImpl::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \ - } - - template -Status GatherNDBase::CommonComputeKernel( +Status GatherNDBase::PrepareCompute( const int64_t batch_dims, const TensorShape& input_shape, - const Tensor* kernel_input_tensor, - Tensor* kernel_output_tensor, const TensorShape& indices_shape, const Tensor* indices_tensor, - const bool fwd) const { - // Note on naming: - // `input` refers to the GatherND `data` input, while `kernel_input` refers to - // what the GatherND[Grad] CUDA kernel accepts as input. - + int64_t& num_slices, + int64_t& slice_size, + IAllocatorUniquePtr& input_slice_offsets_buffer) const { const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1]; - const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1); - const auto slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims); + num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1); + slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims); const auto num_batches = input_shape.SizeToDimension(batch_dims); const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims); const auto num_slices_per_batch = num_slices / num_batches; const TIndex* const indices_data = indices_tensor->Data(); - const void* const kernel_input_data = kernel_input_tensor->DataRaw(); - void* const kernel_output_data = kernel_output_tensor->MutableDataRaw(); std::vector sizes_from_slice_dims(num_slice_dims); { @@ -89,10 +72,10 @@ Status GatherNDBase::CommonComputeKernel( sizes_from_slice_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - auto input_slice_offsets_buffer = GetScratchBuffer(num_slices); + input_slice_offsets_buffer = GetScratchBuffer(num_slices); TArray input_dims(input_shape.GetDims()); - // TODO reuse computed slice offsets from GatherND in GatherNDGrad + ComputeSliceOffsetsImpl( batch_dims, input_dims, @@ -104,39 +87,19 @@ Status GatherNDBase::CommonComputeKernel( indices_data, input_slice_offsets_buffer.get()); - if (fwd) { - MLDataType T_type = kernel_input_tensor->DataType(); - TYPED_FUNCTION_CALL_FWD(float); - TYPED_FUNCTION_CALL_FWD(MLFloat16); - TYPED_FUNCTION_CALL_FWD(double); - } else { -#ifdef ENABLE_TRAINING - MLDataType T_type = kernel_input_tensor->DataType(); - TYPED_FUNCTION_CALL_BWD(float); - TYPED_FUNCTION_CALL_BWD(MLFloat16); - TYPED_FUNCTION_CALL_BWD(double); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Gradient computation is only supported in the training mode."); -#endif - } - return Status::OK(); } -#undef TYPED_FUNCTION_CALL_FWD -#undef TYPED_FUNCTION_CALL_BWD - -#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GatherND, \ - kOnnxDomain, \ - ver, \ - TIndex, \ - kCudaExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) \ - .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GatherND, \ + kOnnxDomain, \ + ver, \ + TIndex, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ GatherND); // TODO: decprecate GatherND-1 after updating training models to opset-12 @@ -145,6 +108,20 @@ REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 1) #endif REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 12) +template +struct GatherNDComputeImpl { + void operator()(const int64_t num_slices, + const int64_t slice_size, + const void* const kernel_input_data, + void* const kernel_output_data, + int64_t* const input_slice_offsets_data) const { + typedef typename ToCudaType::MappedType CudaT; + GatherNDImpl(num_slices, kernel_input_data, + kernel_output_data, slice_size, + input_slice_offsets_data); + } +}; + template Status GatherND::ComputeInternal(OpKernelContext* context) const { auto input_tensor = context->Input(0); @@ -169,16 +146,26 @@ Status GatherND::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch( static_cast(batch_dims_), {input_shape, indices_shape})); - //Output shape + // Output shape std::vector shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1); shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end()); auto output_tensor = context->Output(0, TensorShape(shape)); - //Compute - auto status = CommonComputeKernel(batch_dims_, input_shape, input_tensor, output_tensor, indices_shape, indices_tensor, true); + // Compute + int64_t num_slices; + int64_t slice_size; + IAllocatorUniquePtr input_slice_offsets_buffer; + ORT_RETURN_IF_ERROR(PrepareCompute(batch_dims_, input_shape, indices_shape, indices_tensor, + num_slices, slice_size, input_slice_offsets_buffer)); - return status; + const void* const kernel_input_data = input_tensor->DataRaw(); + void* const kernel_output_data = output_tensor->MutableDataRaw(); + utils::MLTypeCallDispatcher + t_disp(input_tensor->GetElementType()); + t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get()); + + return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.h b/onnxruntime/core/providers/cuda/tensor/gather_nd.h index 9e5f5116a5..e29f6532a6 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.h @@ -23,14 +23,14 @@ class GatherNDBase : public CudaKernel { protected: template - Status CommonComputeKernel( + Status PrepareCompute( const int64_t batch_dims, const TensorShape& input_shape, - const Tensor* input_tensor, - Tensor* output_tensor, const TensorShape& indices_shape, const Tensor* indices_tensor, - const bool fwd) const; + int64_t& num_slices, + int64_t& slice_size, + IAllocatorUniquePtr& input_slice_offsets_buffer) const; int64_t batch_dims_; }; diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu index 91b0169e68..17dba1e402 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/providers/cuda/tensor/gather_nd_impl.h" - #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/atomic/common.cuh" diff --git a/onnxruntime/test/common/cuda_op_test_utils.h b/onnxruntime/test/common/cuda_op_test_utils.h index 520e65798c..cf72f66e32 100644 --- a/onnxruntime/test/common/cuda_op_test_utils.h +++ b/onnxruntime/test/common/cuda_op_test_utils.h @@ -37,5 +37,12 @@ inline bool HasCudaEnvironment(int min_cuda_architecture) { return cuda_architecture >= min_cuda_architecture; } +inline bool NeedSkipIfCudaArchLowerThan(int min_cuda_architecture) { + // only skip when CUDA ep is enabled. + if (DefaultCudaExecutionProvider().get() != nullptr) { + return !HasCudaEnvironment(min_cuda_architecture); + } + return false; +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc index 7b15687120..6cfbeb2e14 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc @@ -79,7 +79,9 @@ TEST(GatherNDOpTest, int64_t) { } TEST(GatherNDOpTest, float) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + if (NeedSkipIfCudaArchLowerThan(600)) { + return; + } RunTest({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); @@ -88,7 +90,9 @@ TEST(GatherNDOpTest, float) { } TEST(GatherNDOpTest, double) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + if (NeedSkipIfCudaArchLowerThan(600)) { + return; + } RunTest({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1}); @@ -133,7 +137,15 @@ TEST(GatherNDOpTest, ContribOpInt32Indices) { #endif -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) { +TEST(GatherNDOpTest, GatherND_slice_float_default_batch_dims) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); + test.AddInput("indices", {3, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 2LL}); + test.AddOutput("output", {3, 2, 4}, {5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 1.0, 2.0, 3.0, 4.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_zero) { OpTester test("GatherND", 12, kOnnxDomain); test.AddAttribute("batch_dims", 0); test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); @@ -142,7 +154,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) { test.Run(); } -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) { +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_1) { OpTester test("GatherND", 12, kOnnxDomain); test.AddAttribute("batch_dims", 1); test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); @@ -151,7 +163,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) { test.Run(); } -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) { +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_2) { OpTester test("GatherND", 12, kOnnxDomain); test.AddAttribute("batch_dims", 1); test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); @@ -160,28 +172,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) { test.Run(); } -#ifdef USE_CUDA -#if __CUDA_ARCH__ >= 600 -TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddAttribute("batch_dims", 1); - test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); - test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_double) { - OpTester test("GatherND", 12, kOnnxDomain); - test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); - test.AddInput("indices", {2, 1}, {1LL, 0LL}); - test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); - test.Run(); -} -#endif -#endif - -TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) { +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_3) { OpTester test("GatherND", 12, kOnnxDomain); test.AddAttribute("batch_dims", 1); test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f)); @@ -190,10 +181,74 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) { test.Run(); } -#ifdef USE_CUDA +TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_two) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 2); + test.AddInput("data", {2, 1, 3, 5}, ValueRange(30, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1, 3, 2}, + {0LL, 0LL, + 0LL, 1LL, + 1LL, 0LL, + 1LL, 1LL, + 0LL, 4LL, + 2LL, 4LL}); + test.AddOutput("output", {2, 1, 3}, {0.0f, 0.1f, 0.5f, 2.1f, 1.9f, 2.9f}); + test.Run(); +} -TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; +TEST(GatherNDOpTest, GatherND_negative_slice_float_batch_dims_one) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("data", {2, 3, 4}, ValueRange(24, 1.0f)); + test.AddInput("indices", {2, 2, 2}, {0LL, -3LL, -1LL, 2LL, -1LL, 0LL, 0LL, -2LL}); + test.AddOutput("output", {2, 2}, {2.0, 11.0, 21.0, 15.0}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_negative_slice_float_batch_dims_two) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 2); + test.AddInput("data", {2, 1, 3, 5}, ValueRange(30, 0.0f, 0.1f)); + test.AddInput("indices", {2, 1, 3, 2}, + {0LL, -5LL, + -3LL, 1LL, + -2LL, 0LL, + -2LL, -4LL, + 0LL, -1LL, + 2LL, -1LL}); + test.AddOutput("output", {2, 1, 3}, {0.0f, 0.1f, 0.5f, 2.1f, 1.9f, 2.9f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_one_1) { + if (NeedSkipIfCudaArchLowerThan(600)) { + return; + } + + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + test.AddInput("data", {2, 2, 2}, ValueRange(8, 0.0, 0.1)); + test.AddInput("indices", {2, 1, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_double_default_batch_dims) { + if (NeedSkipIfCudaArchLowerThan(600)) { + return; + } + + OpTester test("GatherND", 12, kOnnxDomain); + test.AddInput("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); + test.Run(); +} // namespace test + +TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_one_2) { + if (NeedSkipIfCudaArchLowerThan(600)) { + return; + } OpTester test("GatherND", 12, kOnnxDomain); test.AddAttribute("batch_dims", 1); @@ -204,7 +259,9 @@ TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) { } TEST(GatherNDOpTest, GatherND_slice_half) { - if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return; + if (NeedSkipIfCudaArchLowerThan(600)) { + return; + } OpTester test("GatherND", 12, kOnnxDomain); std::vector data_f({0.0f, 0.1f, 0.2f, 0.3f}); @@ -242,7 +299,5 @@ TEST(GatherNDOpTest, GatherND_batch_dims_of_2) { test.Run(); } -#endif - } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc index d7878cf0a6..8c9fb54fdf 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "orttraining/training_ops/cuda/tensor/gather_nd_grad.h" +#include "orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h" #include "core/providers/cuda/shared_inc/cuda_utils.h" namespace onnxruntime { @@ -23,6 +24,20 @@ namespace cuda { REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t) +template +struct GatherNDGradComputeImpl { + void operator()(const int64_t num_slices, + const int64_t slice_size, + const void* const kernel_input_data, + void* const kernel_output_data, + int64_t* const input_slice_offsets_data) const { + typedef typename ToCudaType::MappedType CudaT; + GatherNDGradImpl(num_slices, kernel_input_data, + kernel_output_data, slice_size, + input_slice_offsets_data); + } +}; + template Status GatherNDGrad::ComputeInternal(OpKernelContext* context) const { auto shape_tensor = context->Input(0); @@ -42,7 +57,7 @@ Status GatherNDGrad::ComputeInternal(OpKernelContext* context) const { auto last_indices_dimension = batch_dims_ + indices_shape[indices_shape.NumDimensions() - 1]; - //Output + // Output auto shape_data = shape_tensor->Data(); auto input_shape = TensorShape(shape_data, shape_tensor->SizeInBytes() / sizeof(shape_tensor->DataType())); @@ -59,8 +74,20 @@ Status GatherNDGrad::ComputeInternal(OpKernelContext* context) const { // TODO this memset can be expensive, a sparse tensor representation would help here CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_tensor->MutableDataRaw(), 0, output_tensor->SizeInBytes())); - auto status = CommonComputeKernel(batch_dims_, input_shape, update_tensor, output_tensor, indices_shape, indices_tensor, false); - return status; + // Compute + int64_t num_slices; + int64_t slice_size; + IAllocatorUniquePtr input_slice_offsets_buffer; + ORT_RETURN_IF_ERROR(PrepareCompute(batch_dims_, input_shape, indices_shape, indices_tensor, + num_slices, slice_size, input_slice_offsets_buffer)); + + const void* const kernel_input_data = update_tensor->DataRaw(); + void* const kernel_output_data = output_tensor->MutableDataRaw(); + utils::MLTypeCallDispatcher + t_disp(update_tensor->GetElementType()); + t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get()); + + return Status::OK(); } } // namespace cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.cu similarity index 95% rename from orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu rename to orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.cu index 8c37836705..8abad7e30e 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_gard_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/cuda/tensor/gather_nd_impl.h" +#include "orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/atomic/common.cuh" diff --git a/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h new file mode 100644 index 0000000000..3b19e758e2 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +template +void GatherNDGradImpl( + const size_t num_slices, + const void* update_data, + void* output_data, + const size_t slice_size, + const int64_t* input_slice_offsets_data); + +} // namespace cuda +} // namespace onnxruntime