From fdd0926d0050b85c822f2231dd1d260ddd5449a8 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Sat, 22 Aug 2020 08:00:31 +0800 Subject: [PATCH] int64_t support for GatherND cuda (#4881) Co-authored-by: Vincent Wang --- onnxruntime/core/providers/cuda/tensor/gather_nd.cc | 10 ++++++++-- .../core/providers/cuda/tensor/gather_nd_impl.cu | 1 + .../test/providers/cpu/tensor/gather_nd_op_test.cc | 10 ++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc index 6829405bfd..20e10e7e4f 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc @@ -98,7 +98,13 @@ Status GatherNDBase::PrepareCompute( TIndex, \ kCudaExecutionProvider, \ KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) \ + .TypeConstraint("T", \ + std::vector{ \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + }) \ .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ GatherND); @@ -161,7 +167,7 @@ Status GatherND::ComputeInternal(OpKernelContext* context) const { const void* const kernel_input_data = input_tensor->DataRaw(); void* const kernel_output_data = output_tensor->MutableDataRaw(); - utils::MLTypeCallDispatcher + utils::MLTypeCallDispatcher t_disp(input_tensor->GetElementType()); t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get()); diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu index 17dba1e402..dfb3a0aff9 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.cu @@ -105,6 +105,7 @@ SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int32_t) SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int64_t) SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(int64_t) #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 SPECIALIZED_IMPL(half) SPECIALIZED_IMPL(double) diff --git a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc index 6cfbeb2e14..f498db7637 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc @@ -299,5 +299,15 @@ TEST(GatherNDOpTest, GatherND_batch_dims_of_2) { test.Run(); } +TEST(GatherNDOpTest, GatherND_slice_int64_t) { + OpTester test("GatherND", 12, kOnnxDomain); + std::vector data({0LL, 1LL, 2LL, 3LL}); + std::vector outputs({2LL, 3LL, 0LL, 1LL}); + test.AddInput("data", {2, 2}, data); + test.AddInput("indices", {2, 1}, {1LL, 0LL}); + test.AddOutput("output", {2, 2}, outputs); + test.Run(); +} + } // namespace test } // namespace onnxruntime