Refine GatherND CPU/CUDA Kernels & Add UTs (#3688)

* Refactor GatherND CPU Kernel (Renaming & Simplify)

* Add batch_dim=1 or 2, negative slices tests

* Rename gather_nd_gard_impl.cu

* Use dispatcher to refactor CUDA GatherND/GatherNDGrad

* Change GatherNDBase::CommonComputeKernel --> GatherNDBase::PrepareCompute

* Use HasCudaEnvironment instead of __CUDA_ARCH__ for some double type tests
This commit is contained in:
pengwa 2020-04-30 10:17:54 +08:00 committed by GitHub
parent 58f53966d3
commit 0531acccc5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 242 additions and 149 deletions

View file

@ -26,7 +26,7 @@ ONNX_OPERATOR_KERNEL_EX(GatherND, kMSDomain, 1, kCpuExecutionProvider,
#endif
ONNX_CPU_OPERATOR_KERNEL(
GatherND,
GatherND,
11,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
@ -46,7 +46,7 @@ template <typename Tind>
Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
const auto* input_tensor = context->Input<Tensor>(0);
const auto* indices_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherND op: Input count mismatch");
ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherNDBase PrepareForCompute: Input count mismatch");
const auto& input_shape = input_tensor->Shape();
const auto& indices_shape = indices_tensor->Shape();
@ -54,7 +54,14 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices tensor must has rank larger than 0");
}
int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1] + batch_dims_;
const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1];
const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
const auto slice_size = input_shape.SizeFromDimension(batch_dims_ + num_slice_dims);
const auto num_batches = input_shape.SizeToDimension(batch_dims_);
const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims_);
const auto num_slices_per_batch = num_slices / num_batches;
int64_t last_indices_dimension = batch_dims_ + num_slice_dims;
if (last_indices_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"last dimension of indices must not be larger than rank of input tensor");
@ -63,31 +70,21 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
std::vector<int64_t> shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1);
shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end());
auto* output_tensor = context->Output(0, TensorShape(std::move(shape)));
std::vector<int64_t> element_counts(last_indices_dimension + batch_dims_,
0LL); // Number of elements for each input dimension
std::vector<int64_t> sizes_from_slice_dims(num_slice_dims);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < last_indices_dimension; ++i) {
element_counts[i] = input_shape.SizeFromDimension(i + 1);
}
auto last_dim_size = indices_shape.SizeFromDimension(indices_shape.NumDimensions() - 1);
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
for (int64_t i = batch_dims_ - 1; i >= 0; --i) {
element_counts[last_indices_dimension + i] = indices_shape.SizeFromDimension(i + 1) / last_dim_size;
for (int64_t i = 0; i < num_slice_dims; ++i) {
sizes_from_slice_dims[i] = input_shape.SizeFromDimension(batch_dims_ + i + 1);
}
int64_t err_index = 0;
p.element_bytes = input_tensor->DataType()->Size();
p.element_to_copy = input_shape.SizeFromDimension(last_indices_dimension);
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
const auto* indice_offset = indices_tensor->Data<Tind>();
const int64_t offset_count = indices_shape.Size() / (last_indices_dimension - batch_dims_); // Times to copy
p.element_offsets.assign(offset_count, 0LL);
p.element_count_per_slice = slice_size;
p.bytes_per_slice = p.element_bytes * p.element_count_per_slice;
const auto* indices_data = indices_tensor->Data<Tind>();
p.slice_offsets.assign(num_slices, 0LL);
if (input_tensor->IsDataTypeString()) {
p.input_str_base = static_cast<const std::string*>(input_tensor->DataRaw());
@ -97,29 +94,30 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
p.output_base = static_cast<uint8_t*>(output_tensor->MutableDataRaw());
}
//Compute the element_offset
// Compute the element_offset
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < offset_count; ++i) {
int64_t reminder = i;
for (int64_t j = 0; j < batch_dims_; ++j) {
int64_t idx = reminder / element_counts[last_indices_dimension + j];
p.element_offsets[i] += idx * element_counts[j];
reminder -= (idx * element_counts[last_indices_dimension + j]);
}
for (int64_t j = batch_dims_; j < last_indices_dimension; ++j) {
auto index = *(indice_offset + i * (last_indices_dimension - batch_dims_) + (j - batch_dims_));
auto upper_limit = input_shape[j];
auto lower_limit = -upper_limit;
for (int64_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) {
const size_t batch_idx = slice_idx / num_slices_per_batch;
const size_t input_base_offset = batch_idx * input_batch_stride;
const auto* const slice_indices = indices_data + slice_idx * num_slice_dims;
size_t relative_slice_offset = 0;
for (int64_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) {
int64_t index = static_cast<int64_t>(slice_indices[dim_idx]);
const auto upper_limit = input_shape[batch_dims_ + dim_idx];
const auto lower_limit = -upper_limit;
if (index < lower_limit || index >= upper_limit) {
err_index = index;
break;
}
if (index < 0) {
index += static_cast<Tind>(upper_limit);
}
p.element_offsets[i] += index * element_counts[j];
if (index < 0) index += upper_limit;
relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
}
p.slice_offsets[slice_idx] = input_base_offset + relative_slice_offset;
}
return err_index == 0 ? Status::OK()
@ -142,9 +140,9 @@ Status GatherND::GatherNumber(const Prepare& p) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
memcpy(p.output_base + i * p.bytes_to_copy, p.input_base + p.element_offsets[i] * p.element_bytes,
p.bytes_to_copy);
for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
memcpy(p.output_base + slice_idx * p.bytes_per_slice, p.input_base + p.slice_offsets[slice_idx] * p.element_bytes,
p.bytes_per_slice);
}
return Status::OK();
@ -154,9 +152,10 @@ Status GatherND::GatherString(const Prepare& p) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
for (int64_t j = 0; j < static_cast<int64_t>(p.element_to_copy); ++j) {
p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j];
for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
const int64_t slice_base_offset = slice_idx * p.element_count_per_slice;
for (int64_t j = 0; j < static_cast<int64_t>(p.element_count_per_slice); ++j) {
p.output_str_base[slice_base_offset + j] = p.input_str_base[p.slice_offsets[slice_idx] + j];
}
}

View file

@ -16,19 +16,19 @@ class GatherNDBase {
const std::string* input_str_base;
uint8_t* output_base;
std::string* output_str_base;
uint64_t bytes_to_copy;
uint64_t bytes_per_slice;
uint64_t element_bytes;
uint64_t element_to_copy;
std::vector<uint64_t> element_offsets;
uint64_t element_count_per_slice;
std::vector<uint64_t> slice_offsets;
Prepare() : input_base(nullptr),
input_str_base(nullptr),
output_base(nullptr),
output_str_base(nullptr),
bytes_to_copy(0),
bytes_per_slice(0),
element_bytes(0),
element_to_copy(0),
element_offsets(0) {}
element_count_per_slice(0),
slice_offsets(0) {}
}; // struct Prepare
template <typename Tind>

View file

@ -38,40 +38,23 @@ Status CheckBatchDimensionsMatch(
return Status::OK();
}
#define TYPED_FUNCTION_CALL_FWD(T) \
if (T_type == DataTypeImpl::GetType<T>()) { \
GatherNDImpl<ToCudaType<T>::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \
}
#define TYPED_FUNCTION_CALL_BWD(T) \
if (T_type == DataTypeImpl::GetType<T>()) { \
GatherNDGradImpl<ToCudaType<T>::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \
}
template <typename TIndex>
Status GatherNDBase::CommonComputeKernel(
Status GatherNDBase::PrepareCompute(
const int64_t batch_dims,
const TensorShape& input_shape,
const Tensor* kernel_input_tensor,
Tensor* kernel_output_tensor,
const TensorShape& indices_shape,
const Tensor* indices_tensor,
const bool fwd) const {
// Note on naming:
// `input` refers to the GatherND `data` input, while `kernel_input` refers to
// what the GatherND[Grad] CUDA kernel accepts as input.
int64_t& num_slices,
int64_t& slice_size,
IAllocatorUniquePtr<int64_t>& input_slice_offsets_buffer) const {
const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1];
const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
const auto slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims);
num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims);
const auto num_batches = input_shape.SizeToDimension(batch_dims);
const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims);
const auto num_slices_per_batch = num_slices / num_batches;
const TIndex* const indices_data = indices_tensor->Data<TIndex>();
const void* const kernel_input_data = kernel_input_tensor->DataRaw();
void* const kernel_output_data = kernel_output_tensor->MutableDataRaw();
std::vector<int64_t> sizes_from_slice_dims(num_slice_dims);
{
@ -89,10 +72,10 @@ Status GatherNDBase::CommonComputeKernel(
sizes_from_slice_dims.size() * sizeof(int64_t),
cudaMemcpyHostToDevice));
auto input_slice_offsets_buffer = GetScratchBuffer<int64_t>(num_slices);
input_slice_offsets_buffer = GetScratchBuffer<int64_t>(num_slices);
TArray<int64_t> input_dims(input_shape.GetDims());
// TODO reuse computed slice offsets from GatherND in GatherNDGrad
ComputeSliceOffsetsImpl(
batch_dims,
input_dims,
@ -104,39 +87,19 @@ Status GatherNDBase::CommonComputeKernel(
indices_data,
input_slice_offsets_buffer.get());
if (fwd) {
MLDataType T_type = kernel_input_tensor->DataType();
TYPED_FUNCTION_CALL_FWD(float);
TYPED_FUNCTION_CALL_FWD(MLFloat16);
TYPED_FUNCTION_CALL_FWD(double);
} else {
#ifdef ENABLE_TRAINING
MLDataType T_type = kernel_input_tensor->DataType();
TYPED_FUNCTION_CALL_BWD(float);
TYPED_FUNCTION_CALL_BWD(MLFloat16);
TYPED_FUNCTION_CALL_BWD(double);
#else
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Gradient computation is only supported in the training mode.");
#endif
}
return Status::OK();
}
#undef TYPED_FUNCTION_CALL_FWD
#undef TYPED_FUNCTION_CALL_BWD
#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GatherND, \
kOnnxDomain, \
ver, \
TIndex, \
kCudaExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()}) \
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIndex>()), \
#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GatherND, \
kOnnxDomain, \
ver, \
TIndex, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) \
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIndex>()), \
GatherND<TIndex>);
// TODO: decprecate GatherND-1 after updating training models to opset-12
@ -145,6 +108,20 @@ REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 1)
#endif
REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 12)
template <typename T>
struct GatherNDComputeImpl {
void operator()(const int64_t num_slices,
const int64_t slice_size,
const void* const kernel_input_data,
void* const kernel_output_data,
int64_t* const input_slice_offsets_data) const {
typedef typename ToCudaType<T>::MappedType CudaT;
GatherNDImpl<CudaT>(num_slices, kernel_input_data,
kernel_output_data, slice_size,
input_slice_offsets_data);
}
};
template <typename TIndex>
Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
auto input_tensor = context->Input<Tensor>(0);
@ -169,16 +146,26 @@ Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch(
static_cast<size_t>(batch_dims_), {input_shape, indices_shape}));
//Output shape
// Output shape
std::vector<int64_t> shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1);
shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end());
auto output_tensor = context->Output(0, TensorShape(shape));
//Compute
auto status = CommonComputeKernel<TIndex>(batch_dims_, input_shape, input_tensor, output_tensor, indices_shape, indices_tensor, true);
// Compute
int64_t num_slices;
int64_t slice_size;
IAllocatorUniquePtr<int64_t> input_slice_offsets_buffer;
ORT_RETURN_IF_ERROR(PrepareCompute<TIndex>(batch_dims_, input_shape, indices_shape, indices_tensor,
num_slices, slice_size, input_slice_offsets_buffer));
return status;
const void* const kernel_input_data = input_tensor->DataRaw();
void* const kernel_output_data = output_tensor->MutableDataRaw();
utils::MLTypeCallDispatcher<GatherNDComputeImpl, float, MLFloat16, double>
t_disp(input_tensor->GetElementType());
t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get());
return Status::OK();
}
} // namespace cuda

View file

@ -23,14 +23,14 @@ class GatherNDBase : public CudaKernel {
protected:
template <typename TIndex>
Status CommonComputeKernel(
Status PrepareCompute(
const int64_t batch_dims,
const TensorShape& input_shape,
const Tensor* input_tensor,
Tensor* output_tensor,
const TensorShape& indices_shape,
const Tensor* indices_tensor,
const bool fwd) const;
int64_t& num_slices,
int64_t& slice_size,
IAllocatorUniquePtr<int64_t>& input_slice_offsets_buffer) const;
int64_t batch_dims_;
};

View file

@ -2,7 +2,6 @@
// Licensed under the MIT License.
#include "core/providers/cuda/tensor/gather_nd_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/atomic/common.cuh"

View file

@ -37,5 +37,12 @@ inline bool HasCudaEnvironment(int min_cuda_architecture) {
return cuda_architecture >= min_cuda_architecture;
}
inline bool NeedSkipIfCudaArchLowerThan(int min_cuda_architecture) {
// only skip when CUDA ep is enabled.
if (DefaultCudaExecutionProvider().get() != nullptr) {
return !HasCudaEnvironment(min_cuda_architecture);
}
return false;
}
} // namespace test
} // namespace onnxruntime

View file

@ -79,7 +79,9 @@ TEST(GatherNDOpTest, int64_t) {
}
TEST(GatherNDOpTest, float) {
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
if (NeedSkipIfCudaArchLowerThan(600)) {
return;
}
RunTest<float>({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
@ -88,7 +90,9 @@ TEST(GatherNDOpTest, float) {
}
TEST(GatherNDOpTest, double) {
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
if (NeedSkipIfCudaArchLowerThan(600)) {
return;
}
RunTest<double>({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1});
@ -133,7 +137,15 @@ TEST(GatherNDOpTest, ContribOpInt32Indices) {
#endif
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) {
TEST(GatherNDOpTest, GatherND_slice_float_default_batch_dims) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
test.AddInput<int64_t>("indices", {3, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 2LL});
test.AddOutput<float>("output", {3, 2, 4}, {5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 1.0, 2.0, 3.0, 4.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0});
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_zero) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 0);
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
@ -142,7 +154,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) {
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) {
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_1) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
@ -151,7 +163,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) {
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) {
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_2) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
test.AddInput<float>("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f));
@ -160,28 +172,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) {
test.Run();
}
#ifdef USE_CUDA
#if __CUDA_ARCH__ >= 600
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
test.AddInput<double>("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f));
test.AddInput<int64_t>("indices", {2, 1, 1}, {1LL, 0LL});
test.AddOutput<double>("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f});
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_double) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddInput<double>("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f});
test.AddInput<int64_t>("indices", {2, 1}, {1LL, 0LL});
test.AddOutput<double>("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
test.Run();
}
#endif
#endif
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) {
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_3) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
test.AddInput<float>("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f));
@ -190,10 +181,74 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) {
test.Run();
}
#ifdef USE_CUDA
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_two) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 2);
test.AddInput<float>("data", {2, 1, 3, 5}, ValueRange(30, 0.0f, 0.1f));
test.AddInput<int64_t>("indices", {2, 1, 3, 2},
{0LL, 0LL,
0LL, 1LL,
1LL, 0LL,
1LL, 1LL,
0LL, 4LL,
2LL, 4LL});
test.AddOutput<float>("output", {2, 1, 3}, {0.0f, 0.1f, 0.5f, 2.1f, 1.9f, 2.9f});
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) {
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
TEST(GatherNDOpTest, GatherND_negative_slice_float_batch_dims_one) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
test.AddInput<int64_t>("indices", {2, 2, 2}, {0LL, -3LL, -1LL, 2LL, -1LL, 0LL, 0LL, -2LL});
test.AddOutput<float>("output", {2, 2}, {2.0, 11.0, 21.0, 15.0});
test.Run();
}
TEST(GatherNDOpTest, GatherND_negative_slice_float_batch_dims_two) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 2);
test.AddInput<float>("data", {2, 1, 3, 5}, ValueRange(30, 0.0f, 0.1f));
test.AddInput<int64_t>("indices", {2, 1, 3, 2},
{0LL, -5LL,
-3LL, 1LL,
-2LL, 0LL,
-2LL, -4LL,
0LL, -1LL,
2LL, -1LL});
test.AddOutput<float>("output", {2, 1, 3}, {0.0f, 0.1f, 0.5f, 2.1f, 1.9f, 2.9f});
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_one_1) {
if (NeedSkipIfCudaArchLowerThan(600)) {
return;
}
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
test.AddInput<double>("data", {2, 2, 2}, ValueRange(8, 0.0, 0.1));
test.AddInput<int64_t>("indices", {2, 1, 1}, {1LL, 0LL});
test.AddOutput<double>("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f});
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_double_default_batch_dims) {
if (NeedSkipIfCudaArchLowerThan(600)) {
return;
}
OpTester test("GatherND", 12, kOnnxDomain);
test.AddInput<double>("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f});
test.AddInput<int64_t>("indices", {2, 1}, {1LL, 0LL});
test.AddOutput<double>("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
test.Run();
} // namespace test
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_one_2) {
if (NeedSkipIfCudaArchLowerThan(600)) {
return;
}
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
@ -204,7 +259,9 @@ TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) {
}
TEST(GatherNDOpTest, GatherND_slice_half) {
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
if (NeedSkipIfCudaArchLowerThan(600)) {
return;
}
OpTester test("GatherND", 12, kOnnxDomain);
std::vector<float> data_f({0.0f, 0.1f, 0.2f, 0.3f});
@ -242,7 +299,5 @@ TEST(GatherNDOpTest, GatherND_batch_dims_of_2) {
test.Run();
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/tensor/gather_nd_grad.h"
#include "orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
@ -23,6 +24,20 @@ namespace cuda {
REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t)
template <typename T>
struct GatherNDGradComputeImpl {
void operator()(const int64_t num_slices,
const int64_t slice_size,
const void* const kernel_input_data,
void* const kernel_output_data,
int64_t* const input_slice_offsets_data) const {
typedef typename ToCudaType<T>::MappedType CudaT;
GatherNDGradImpl<CudaT>(num_slices, kernel_input_data,
kernel_output_data, slice_size,
input_slice_offsets_data);
}
};
template <typename TIndex>
Status GatherNDGrad<TIndex>::ComputeInternal(OpKernelContext* context) const {
auto shape_tensor = context->Input<Tensor>(0);
@ -42,7 +57,7 @@ Status GatherNDGrad<TIndex>::ComputeInternal(OpKernelContext* context) const {
auto last_indices_dimension = batch_dims_ + indices_shape[indices_shape.NumDimensions() - 1];
//Output
// Output
auto shape_data = shape_tensor->Data<int64_t>();
auto input_shape = TensorShape(shape_data, shape_tensor->SizeInBytes() / sizeof(shape_tensor->DataType()));
@ -59,8 +74,20 @@ Status GatherNDGrad<TIndex>::ComputeInternal(OpKernelContext* context) const {
// TODO this memset can be expensive, a sparse tensor representation would help here
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_tensor->MutableDataRaw(), 0, output_tensor->SizeInBytes()));
auto status = CommonComputeKernel<TIndex>(batch_dims_, input_shape, update_tensor, output_tensor, indices_shape, indices_tensor, false);
return status;
// Compute
int64_t num_slices;
int64_t slice_size;
IAllocatorUniquePtr<int64_t> input_slice_offsets_buffer;
ORT_RETURN_IF_ERROR(PrepareCompute<TIndex>(batch_dims_, input_shape, indices_shape, indices_tensor,
num_slices, slice_size, input_slice_offsets_buffer));
const void* const kernel_input_data = update_tensor->DataRaw();
void* const kernel_output_data = output_tensor->MutableDataRaw();
utils::MLTypeCallDispatcher<GatherNDGradComputeImpl, float, MLFloat16, double>
t_disp(update_tensor->GetElementType());
t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get());
return Status::OK();
}
} // namespace cuda

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/tensor/gather_nd_impl.h"
#include "orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/atomic/common.cuh"

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace cuda {
template <typename T>
void GatherNDGradImpl(
const size_t num_slices,
const void* update_data,
void* output_data,
const size_t slice_size,
const int64_t* input_slice_offsets_data);
} // namespace cuda
} // namespace onnxruntime