mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
int64_t support for GatherND cuda (#4881)
Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
parent
acbf6d15c6
commit
fdd0926d00
3 changed files with 19 additions and 2 deletions
|
|
@ -98,7 +98,13 @@ Status GatherNDBase::PrepareCompute(
|
|||
TIndex, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) \
|
||||
.TypeConstraint("T", \
|
||||
std::vector<MLDataType>{ \
|
||||
DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<double>(), \
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<int64_t>(), \
|
||||
}) \
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIndex>()), \
|
||||
GatherND<TIndex>);
|
||||
|
||||
|
|
@ -161,7 +167,7 @@ Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
const void* const kernel_input_data = input_tensor->DataRaw();
|
||||
void* const kernel_output_data = output_tensor->MutableDataRaw();
|
||||
utils::MLTypeCallDispatcher<GatherNDComputeImpl, float, MLFloat16, double>
|
||||
utils::MLTypeCallDispatcher<GatherNDComputeImpl, float, MLFloat16, double, int64_t>
|
||||
t_disp(input_tensor->GetElementType());
|
||||
t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get());
|
||||
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int32_t)
|
|||
SPECIALIZED_COMPUTE_SLICE_OFFSETS_IMPL(int64_t)
|
||||
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(int64_t)
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
|
||||
SPECIALIZED_IMPL(half)
|
||||
SPECIALIZED_IMPL(double)
|
||||
|
|
|
|||
|
|
@ -299,5 +299,15 @@ TEST(GatherNDOpTest, GatherND_batch_dims_of_2) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// GatherND with int64_t values in the "data" tensor (the indices are int64_t too).
// The 2x2 input [[0, 1], [2, 3]] is indexed with [[1], [0]], so each index selects
// an entire row and the expected result is the two rows in swapped order.
// NOTE(review): the 12 passed to OpTester is presumably the opset version - confirm.
TEST(GatherNDOpTest, GatherND_slice_int64_t) {
  OpTester tester("GatherND", 12, kOnnxDomain);

  // Row-major contents of the 2x2 input and of the expected 2x2 output.
  const std::vector<int64_t> input_values{0LL, 1LL, 2LL, 3LL};
  const std::vector<int64_t> expected_values{2LL, 3LL, 0LL, 1LL};

  tester.AddInput<int64_t>("data", {2, 2}, input_values);
  tester.AddInput<int64_t>("indices", {2, 1}, {1LL, 0LL});
  tester.AddOutput<int64_t>("output", {2, 2}, expected_values);

  tester.Run();
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue