int64_t support for GatherND cuda (#4881)

Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
Vincent Wang 2020-08-22 08:00:31 +08:00 committed by GitHub
parent acbf6d15c6
commit fdd0926d00
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 2 deletions

View file

@ -98,7 +98,13 @@ Status GatherNDBase::PrepareCompute(
TIndex, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) \
.TypeConstraint("T", \
std::vector<MLDataType>{ \
DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<double>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int64_t>(), \
}) \
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIndex>()), \
GatherND<TIndex>);
@ -161,7 +167,7 @@ Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
const void* const kernel_input_data = input_tensor->DataRaw();
void* const kernel_output_data = output_tensor->MutableDataRaw();
utils::MLTypeCallDispatcher<GatherNDComputeImpl, float, MLFloat16, double>
utils::MLTypeCallDispatcher<GatherNDComputeImpl, float, MLFloat16, double, int64_t>
t_disp(input_tensor->GetElementType());
t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get());

View file

@ -105,6 +105,7 @@ SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int32_t)
SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int64_t)
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(int64_t)
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
SPECIALIZED_IMPL(half)
SPECIALIZED_IMPL(double)

View file

@ -299,5 +299,15 @@ TEST(GatherNDOpTest, GatherND_batch_dims_of_2) {
test.Run();
}
// Exercises GatherND with int64_t element data: a 2x2 "data" tensor is indexed
// by a 2x1 "indices" tensor holding {1, 0}, and the expected "output" is the
// two rows of "data" in swapped order.
TEST(GatherNDOpTest, GatherND_slice_int64_t) {
  OpTester test("GatherND", 12, kOnnxDomain);

  // Row-major contents of the 2x2 input: row 0 = {0, 1}, row 1 = {2, 3}.
  const std::vector<int64_t> input_values{0LL, 1LL, 2LL, 3LL};
  // One index per output row; each picks a whole row of the input.
  const std::vector<int64_t> row_indices{1LL, 0LL};
  // Row 1 followed by row 0.
  const std::vector<int64_t> expected_values{2LL, 3LL, 0LL, 1LL};

  test.AddInput<int64_t>("data", {2, 2}, input_values);
  test.AddInput<int64_t>("indices", {2, 1}, row_indices);
  test.AddOutput<int64_t>("output", {2, 2}, expected_values);

  test.Run();
}
} // namespace test
} // namespace onnxruntime