mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Refine GatherND CPU/CUDA Kernels & Add UTs (#3688)
* Refactor GatherND CPU Kernel (Renaming & Simplify) * Add batch_dim=1 or 2, negative slices tests * Rename gather_nd_grad_impl.cu * Use dispatcher to refactor CUDA GatherND/GatherNDGrad * Change GatherNDBase::CommonComputeKernel --> GatherNDBase::PrepareCompute * Use HasCudaEnvironment instead of __CUDA_ARCH__ for some double type tests
This commit is contained in:
parent
58f53966d3
commit
0531acccc5
10 changed files with 242 additions and 149 deletions
|
|
@ -26,7 +26,7 @@ ONNX_OPERATOR_KERNEL_EX(GatherND, kMSDomain, 1, kCpuExecutionProvider,
|
|||
#endif
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
GatherND,
|
||||
GatherND,
|
||||
11,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
|
|
@ -46,7 +46,7 @@ template <typename Tind>
|
|||
Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
|
||||
const auto* input_tensor = context->Input<Tensor>(0);
|
||||
const auto* indices_tensor = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherND op: Input count mismatch");
|
||||
ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherNDBase PrepareForCompute: Input count mismatch");
|
||||
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
const auto& indices_shape = indices_tensor->Shape();
|
||||
|
|
@ -54,7 +54,14 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices tensor must has rank larger than 0");
|
||||
}
|
||||
|
||||
int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1] + batch_dims_;
|
||||
const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1];
|
||||
const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
|
||||
const auto slice_size = input_shape.SizeFromDimension(batch_dims_ + num_slice_dims);
|
||||
const auto num_batches = input_shape.SizeToDimension(batch_dims_);
|
||||
const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims_);
|
||||
const auto num_slices_per_batch = num_slices / num_batches;
|
||||
|
||||
int64_t last_indices_dimension = batch_dims_ + num_slice_dims;
|
||||
if (last_indices_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"last dimension of indices must not be larger than rank of input tensor");
|
||||
|
|
@ -63,31 +70,21 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
|
|||
std::vector<int64_t> shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1);
|
||||
shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end());
|
||||
auto* output_tensor = context->Output(0, TensorShape(std::move(shape)));
|
||||
std::vector<int64_t> element_counts(last_indices_dimension + batch_dims_,
|
||||
0LL); // Number of elements for each input dimension
|
||||
|
||||
std::vector<int64_t> sizes_from_slice_dims(num_slice_dims);
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < last_indices_dimension; ++i) {
|
||||
element_counts[i] = input_shape.SizeFromDimension(i + 1);
|
||||
}
|
||||
|
||||
auto last_dim_size = indices_shape.SizeFromDimension(indices_shape.NumDimensions() - 1);
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = batch_dims_ - 1; i >= 0; --i) {
|
||||
element_counts[last_indices_dimension + i] = indices_shape.SizeFromDimension(i + 1) / last_dim_size;
|
||||
for (int64_t i = 0; i < num_slice_dims; ++i) {
|
||||
sizes_from_slice_dims[i] = input_shape.SizeFromDimension(batch_dims_ + i + 1);
|
||||
}
|
||||
|
||||
int64_t err_index = 0;
|
||||
p.element_bytes = input_tensor->DataType()->Size();
|
||||
p.element_to_copy = input_shape.SizeFromDimension(last_indices_dimension);
|
||||
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
|
||||
const auto* indice_offset = indices_tensor->Data<Tind>();
|
||||
const int64_t offset_count = indices_shape.Size() / (last_indices_dimension - batch_dims_); // Times to copy
|
||||
p.element_offsets.assign(offset_count, 0LL);
|
||||
p.element_count_per_slice = slice_size;
|
||||
p.bytes_per_slice = p.element_bytes * p.element_count_per_slice;
|
||||
const auto* indices_data = indices_tensor->Data<Tind>();
|
||||
p.slice_offsets.assign(num_slices, 0LL);
|
||||
|
||||
if (input_tensor->IsDataTypeString()) {
|
||||
p.input_str_base = static_cast<const std::string*>(input_tensor->DataRaw());
|
||||
|
|
@ -97,29 +94,30 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
|
|||
p.output_base = static_cast<uint8_t*>(output_tensor->MutableDataRaw());
|
||||
}
|
||||
|
||||
//Compute the element_offset
|
||||
// Compute the element_offset
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < offset_count; ++i) {
|
||||
int64_t reminder = i;
|
||||
for (int64_t j = 0; j < batch_dims_; ++j) {
|
||||
int64_t idx = reminder / element_counts[last_indices_dimension + j];
|
||||
p.element_offsets[i] += idx * element_counts[j];
|
||||
reminder -= (idx * element_counts[last_indices_dimension + j]);
|
||||
}
|
||||
for (int64_t j = batch_dims_; j < last_indices_dimension; ++j) {
|
||||
auto index = *(indice_offset + i * (last_indices_dimension - batch_dims_) + (j - batch_dims_));
|
||||
auto upper_limit = input_shape[j];
|
||||
auto lower_limit = -upper_limit;
|
||||
for (int64_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) {
|
||||
const size_t batch_idx = slice_idx / num_slices_per_batch;
|
||||
const size_t input_base_offset = batch_idx * input_batch_stride;
|
||||
|
||||
const auto* const slice_indices = indices_data + slice_idx * num_slice_dims;
|
||||
size_t relative_slice_offset = 0;
|
||||
for (int64_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) {
|
||||
int64_t index = static_cast<int64_t>(slice_indices[dim_idx]);
|
||||
const auto upper_limit = input_shape[batch_dims_ + dim_idx];
|
||||
const auto lower_limit = -upper_limit;
|
||||
if (index < lower_limit || index >= upper_limit) {
|
||||
err_index = index;
|
||||
break;
|
||||
}
|
||||
if (index < 0) {
|
||||
index += static_cast<Tind>(upper_limit);
|
||||
}
|
||||
p.element_offsets[i] += index * element_counts[j];
|
||||
if (index < 0) index += upper_limit;
|
||||
|
||||
relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
|
||||
}
|
||||
|
||||
p.slice_offsets[slice_idx] = input_base_offset + relative_slice_offset;
|
||||
}
|
||||
|
||||
return err_index == 0 ? Status::OK()
|
||||
|
|
@ -142,9 +140,9 @@ Status GatherND::GatherNumber(const Prepare& p) const {
|
|||
#ifdef _OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
|
||||
memcpy(p.output_base + i * p.bytes_to_copy, p.input_base + p.element_offsets[i] * p.element_bytes,
|
||||
p.bytes_to_copy);
|
||||
for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
|
||||
memcpy(p.output_base + slice_idx * p.bytes_per_slice, p.input_base + p.slice_offsets[slice_idx] * p.element_bytes,
|
||||
p.bytes_per_slice);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -154,9 +152,10 @@ Status GatherND::GatherString(const Prepare& p) const {
|
|||
#ifdef _OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
|
||||
for (int64_t j = 0; j < static_cast<int64_t>(p.element_to_copy); ++j) {
|
||||
p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j];
|
||||
for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
|
||||
const int64_t slice_base_offset = slice_idx * p.element_count_per_slice;
|
||||
for (int64_t j = 0; j < static_cast<int64_t>(p.element_count_per_slice); ++j) {
|
||||
p.output_str_base[slice_base_offset + j] = p.input_str_base[p.slice_offsets[slice_idx] + j];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,19 +16,19 @@ class GatherNDBase {
|
|||
const std::string* input_str_base;
|
||||
uint8_t* output_base;
|
||||
std::string* output_str_base;
|
||||
uint64_t bytes_to_copy;
|
||||
uint64_t bytes_per_slice;
|
||||
uint64_t element_bytes;
|
||||
uint64_t element_to_copy;
|
||||
std::vector<uint64_t> element_offsets;
|
||||
uint64_t element_count_per_slice;
|
||||
std::vector<uint64_t> slice_offsets;
|
||||
|
||||
Prepare() : input_base(nullptr),
|
||||
input_str_base(nullptr),
|
||||
output_base(nullptr),
|
||||
output_str_base(nullptr),
|
||||
bytes_to_copy(0),
|
||||
bytes_per_slice(0),
|
||||
element_bytes(0),
|
||||
element_to_copy(0),
|
||||
element_offsets(0) {}
|
||||
element_count_per_slice(0),
|
||||
slice_offsets(0) {}
|
||||
}; // struct Prepare
|
||||
|
||||
template <typename Tind>
|
||||
|
|
|
|||
|
|
@ -38,40 +38,23 @@ Status CheckBatchDimensionsMatch(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define TYPED_FUNCTION_CALL_FWD(T) \
|
||||
if (T_type == DataTypeImpl::GetType<T>()) { \
|
||||
GatherNDImpl<ToCudaType<T>::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \
|
||||
}
|
||||
|
||||
#define TYPED_FUNCTION_CALL_BWD(T) \
|
||||
if (T_type == DataTypeImpl::GetType<T>()) { \
|
||||
GatherNDGradImpl<ToCudaType<T>::MappedType>(num_slices, kernel_input_data, kernel_output_data, slice_size, input_slice_offsets_buffer.get()); \
|
||||
}
|
||||
|
||||
|
||||
template <typename TIndex>
|
||||
Status GatherNDBase::CommonComputeKernel(
|
||||
Status GatherNDBase::PrepareCompute(
|
||||
const int64_t batch_dims,
|
||||
const TensorShape& input_shape,
|
||||
const Tensor* kernel_input_tensor,
|
||||
Tensor* kernel_output_tensor,
|
||||
const TensorShape& indices_shape,
|
||||
const Tensor* indices_tensor,
|
||||
const bool fwd) const {
|
||||
// Note on naming:
|
||||
// `input` refers to the GatherND `data` input, while `kernel_input` refers to
|
||||
// what the GatherND[Grad] CUDA kernel accepts as input.
|
||||
|
||||
int64_t& num_slices,
|
||||
int64_t& slice_size,
|
||||
IAllocatorUniquePtr<int64_t>& input_slice_offsets_buffer) const {
|
||||
const auto num_slice_dims = indices_shape[indices_shape.NumDimensions() - 1];
|
||||
const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
|
||||
const auto slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims);
|
||||
num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
|
||||
slice_size = input_shape.SizeFromDimension(batch_dims + num_slice_dims);
|
||||
const auto num_batches = input_shape.SizeToDimension(batch_dims);
|
||||
const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims);
|
||||
const auto num_slices_per_batch = num_slices / num_batches;
|
||||
|
||||
const TIndex* const indices_data = indices_tensor->Data<TIndex>();
|
||||
const void* const kernel_input_data = kernel_input_tensor->DataRaw();
|
||||
void* const kernel_output_data = kernel_output_tensor->MutableDataRaw();
|
||||
|
||||
std::vector<int64_t> sizes_from_slice_dims(num_slice_dims);
|
||||
{
|
||||
|
|
@ -89,10 +72,10 @@ Status GatherNDBase::CommonComputeKernel(
|
|||
sizes_from_slice_dims.size() * sizeof(int64_t),
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
auto input_slice_offsets_buffer = GetScratchBuffer<int64_t>(num_slices);
|
||||
input_slice_offsets_buffer = GetScratchBuffer<int64_t>(num_slices);
|
||||
|
||||
TArray<int64_t> input_dims(input_shape.GetDims());
|
||||
// TODO reuse computed slice offsets from GatherND in GatherNDGrad
|
||||
|
||||
ComputeSliceOffsetsImpl(
|
||||
batch_dims,
|
||||
input_dims,
|
||||
|
|
@ -104,39 +87,19 @@ Status GatherNDBase::CommonComputeKernel(
|
|||
indices_data,
|
||||
input_slice_offsets_buffer.get());
|
||||
|
||||
if (fwd) {
|
||||
MLDataType T_type = kernel_input_tensor->DataType();
|
||||
TYPED_FUNCTION_CALL_FWD(float);
|
||||
TYPED_FUNCTION_CALL_FWD(MLFloat16);
|
||||
TYPED_FUNCTION_CALL_FWD(double);
|
||||
} else {
|
||||
#ifdef ENABLE_TRAINING
|
||||
MLDataType T_type = kernel_input_tensor->DataType();
|
||||
TYPED_FUNCTION_CALL_BWD(float);
|
||||
TYPED_FUNCTION_CALL_BWD(MLFloat16);
|
||||
TYPED_FUNCTION_CALL_BWD(double);
|
||||
#else
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Gradient computation is only supported in the training mode.");
|
||||
#endif
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#undef TYPED_FUNCTION_CALL_FWD
|
||||
#undef TYPED_FUNCTION_CALL_BWD
|
||||
|
||||
#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
GatherND, \
|
||||
kOnnxDomain, \
|
||||
ver, \
|
||||
TIndex, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()}) \
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIndex>()), \
|
||||
#define REGISTER_KERNEL_TYPED_GATHER_ND(TIndex, ver) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
GatherND, \
|
||||
kOnnxDomain, \
|
||||
ver, \
|
||||
TIndex, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) \
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIndex>()), \
|
||||
GatherND<TIndex>);
|
||||
|
||||
// TODO: deprecate GatherND-1 after updating training models to opset-12
|
||||
|
|
@ -145,6 +108,20 @@ REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 1)
|
|||
#endif
|
||||
REGISTER_KERNEL_TYPED_GATHER_ND(int64_t, 12)
|
||||
|
||||
template <typename T>
|
||||
struct GatherNDComputeImpl {
|
||||
void operator()(const int64_t num_slices,
|
||||
const int64_t slice_size,
|
||||
const void* const kernel_input_data,
|
||||
void* const kernel_output_data,
|
||||
int64_t* const input_slice_offsets_data) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
GatherNDImpl<CudaT>(num_slices, kernel_input_data,
|
||||
kernel_output_data, slice_size,
|
||||
input_slice_offsets_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TIndex>
|
||||
Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
|
||||
auto input_tensor = context->Input<Tensor>(0);
|
||||
|
|
@ -169,16 +146,26 @@ Status GatherND<TIndex>::ComputeInternal(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch(
|
||||
static_cast<size_t>(batch_dims_), {input_shape, indices_shape}));
|
||||
|
||||
//Output shape
|
||||
// Output shape
|
||||
std::vector<int64_t> shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1);
|
||||
shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end());
|
||||
|
||||
auto output_tensor = context->Output(0, TensorShape(shape));
|
||||
|
||||
//Compute
|
||||
auto status = CommonComputeKernel<TIndex>(batch_dims_, input_shape, input_tensor, output_tensor, indices_shape, indices_tensor, true);
|
||||
// Compute
|
||||
int64_t num_slices;
|
||||
int64_t slice_size;
|
||||
IAllocatorUniquePtr<int64_t> input_slice_offsets_buffer;
|
||||
ORT_RETURN_IF_ERROR(PrepareCompute<TIndex>(batch_dims_, input_shape, indices_shape, indices_tensor,
|
||||
num_slices, slice_size, input_slice_offsets_buffer));
|
||||
|
||||
return status;
|
||||
const void* const kernel_input_data = input_tensor->DataRaw();
|
||||
void* const kernel_output_data = output_tensor->MutableDataRaw();
|
||||
utils::MLTypeCallDispatcher<GatherNDComputeImpl, float, MLFloat16, double>
|
||||
t_disp(input_tensor->GetElementType());
|
||||
t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@ class GatherNDBase : public CudaKernel {
|
|||
|
||||
protected:
|
||||
template <typename TIndex>
|
||||
Status CommonComputeKernel(
|
||||
Status PrepareCompute(
|
||||
const int64_t batch_dims,
|
||||
const TensorShape& input_shape,
|
||||
const Tensor* input_tensor,
|
||||
Tensor* output_tensor,
|
||||
const TensorShape& indices_shape,
|
||||
const Tensor* indices_tensor,
|
||||
const bool fwd) const;
|
||||
int64_t& num_slices,
|
||||
int64_t& slice_size,
|
||||
IAllocatorUniquePtr<int64_t>& input_slice_offsets_buffer) const;
|
||||
|
||||
int64_t batch_dims_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/tensor/gather_nd_impl.h"
|
||||
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/atomic/common.cuh"
|
||||
|
||||
|
|
|
|||
|
|
@ -37,5 +37,12 @@ inline bool HasCudaEnvironment(int min_cuda_architecture) {
|
|||
return cuda_architecture >= min_cuda_architecture;
|
||||
}
|
||||
|
||||
inline bool NeedSkipIfCudaArchLowerThan(int min_cuda_architecture) {
|
||||
// only skip when CUDA ep is enabled.
|
||||
if (DefaultCudaExecutionProvider().get() != nullptr) {
|
||||
return !HasCudaEnvironment(min_cuda_architecture);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -79,7 +79,9 @@ TEST(GatherNDOpTest, int64_t) {
|
|||
}
|
||||
|
||||
TEST(GatherNDOpTest, float) {
|
||||
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
|
||||
if (NeedSkipIfCudaArchLowerThan(600)) {
|
||||
return;
|
||||
}
|
||||
|
||||
RunTest<float>({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
|
||||
|
||||
|
|
@ -88,7 +90,9 @@ TEST(GatherNDOpTest, float) {
|
|||
}
|
||||
|
||||
TEST(GatherNDOpTest, double) {
|
||||
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
|
||||
if (NeedSkipIfCudaArchLowerThan(600)) {
|
||||
return;
|
||||
}
|
||||
|
||||
RunTest<double>({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1});
|
||||
|
||||
|
|
@ -133,7 +137,15 @@ TEST(GatherNDOpTest, ContribOpInt32Indices) {
|
|||
|
||||
#endif
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) {
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_default_batch_dims) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
|
||||
test.AddInput<int64_t>("indices", {3, 2, 2}, {0LL, 1LL, 0LL, 2LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 2LL});
|
||||
test.AddOutput<float>("output", {3, 2, 4}, {5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 1.0, 2.0, 3.0, 4.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_zero) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 0);
|
||||
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
|
||||
|
|
@ -142,7 +154,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_0) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) {
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_1) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 1);
|
||||
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
|
||||
|
|
@ -151,7 +163,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_1) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) {
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_2) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 1);
|
||||
test.AddInput<float>("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f));
|
||||
|
|
@ -160,28 +172,7 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_2) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if __CUDA_ARCH__ >= 600
|
||||
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 1);
|
||||
test.AddInput<double>("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f));
|
||||
test.AddInput<int64_t>("indices", {2, 1, 1}, {1LL, 0LL});
|
||||
test.AddOutput<double>("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_double) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddInput<double>("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f});
|
||||
test.AddInput<int64_t>("indices", {2, 1}, {1LL, 0LL});
|
||||
test.AddOutput<double>("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
|
||||
test.Run();
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) {
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_one_3) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 1);
|
||||
test.AddInput<float>("data", {2, 2, 2}, ValueRange(8, 0.0f, 0.1f));
|
||||
|
|
@ -190,10 +181,74 @@ TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_4) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TEST(GatherNDOpTest, GatherND_slice_float_batch_dims_two) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 2);
|
||||
test.AddInput<float>("data", {2, 1, 3, 5}, ValueRange(30, 0.0f, 0.1f));
|
||||
test.AddInput<int64_t>("indices", {2, 1, 3, 2},
|
||||
{0LL, 0LL,
|
||||
0LL, 1LL,
|
||||
1LL, 0LL,
|
||||
1LL, 1LL,
|
||||
0LL, 4LL,
|
||||
2LL, 4LL});
|
||||
test.AddOutput<float>("output", {2, 1, 3}, {0.0f, 0.1f, 0.5f, 2.1f, 1.9f, 2.9f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) {
|
||||
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
|
||||
TEST(GatherNDOpTest, GatherND_negative_slice_float_batch_dims_one) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 1);
|
||||
test.AddInput<float>("data", {2, 3, 4}, ValueRange(24, 1.0f));
|
||||
test.AddInput<int64_t>("indices", {2, 2, 2}, {0LL, -3LL, -1LL, 2LL, -1LL, 0LL, 0LL, -2LL});
|
||||
test.AddOutput<float>("output", {2, 2}, {2.0, 11.0, 21.0, 15.0});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_negative_slice_float_batch_dims_two) {
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 2);
|
||||
test.AddInput<float>("data", {2, 1, 3, 5}, ValueRange(30, 0.0f, 0.1f));
|
||||
test.AddInput<int64_t>("indices", {2, 1, 3, 2},
|
||||
{0LL, -5LL,
|
||||
-3LL, 1LL,
|
||||
-2LL, 0LL,
|
||||
-2LL, -4LL,
|
||||
0LL, -1LL,
|
||||
2LL, -1LL});
|
||||
test.AddOutput<float>("output", {2, 1, 3}, {0.0f, 0.1f, 0.5f, 2.1f, 1.9f, 2.9f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_one_1) {
|
||||
if (NeedSkipIfCudaArchLowerThan(600)) {
|
||||
return;
|
||||
}
|
||||
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 1);
|
||||
test.AddInput<double>("data", {2, 2, 2}, ValueRange(8, 0.0, 0.1));
|
||||
test.AddInput<int64_t>("indices", {2, 1, 1}, {1LL, 0LL});
|
||||
test.AddOutput<double>("output", {2, 1, 2}, {0.2f, 0.3f, 0.4f, 0.5f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_double_default_batch_dims) {
|
||||
if (NeedSkipIfCudaArchLowerThan(600)) {
|
||||
return;
|
||||
}
|
||||
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddInput<double>("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f});
|
||||
test.AddInput<int64_t>("indices", {2, 1}, {1LL, 0LL});
|
||||
test.AddOutput<double>("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_one_2) {
|
||||
if (NeedSkipIfCudaArchLowerThan(600)) {
|
||||
return;
|
||||
}
|
||||
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
test.AddAttribute<int64_t>("batch_dims", 1);
|
||||
|
|
@ -204,7 +259,9 @@ TEST(GatherNDOpTest, GatherND_slice_double_batch_dims_3) {
|
|||
}
|
||||
|
||||
TEST(GatherNDOpTest, GatherND_slice_half) {
|
||||
if (!HasCudaEnvironment(600 /*min_cuda_architecture*/)) return;
|
||||
if (NeedSkipIfCudaArchLowerThan(600)) {
|
||||
return;
|
||||
}
|
||||
|
||||
OpTester test("GatherND", 12, kOnnxDomain);
|
||||
std::vector<float> data_f({0.0f, 0.1f, 0.2f, 0.3f});
|
||||
|
|
@ -242,7 +299,5 @@ TEST(GatherNDOpTest, GatherND_batch_dims_of_2) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/tensor/gather_nd_grad.h"
|
||||
#include "orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -23,6 +24,20 @@ namespace cuda {
|
|||
|
||||
REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t)
|
||||
|
||||
template <typename T>
|
||||
struct GatherNDGradComputeImpl {
|
||||
void operator()(const int64_t num_slices,
|
||||
const int64_t slice_size,
|
||||
const void* const kernel_input_data,
|
||||
void* const kernel_output_data,
|
||||
int64_t* const input_slice_offsets_data) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
GatherNDGradImpl<CudaT>(num_slices, kernel_input_data,
|
||||
kernel_output_data, slice_size,
|
||||
input_slice_offsets_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TIndex>
|
||||
Status GatherNDGrad<TIndex>::ComputeInternal(OpKernelContext* context) const {
|
||||
auto shape_tensor = context->Input<Tensor>(0);
|
||||
|
|
@ -42,7 +57,7 @@ Status GatherNDGrad<TIndex>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
auto last_indices_dimension = batch_dims_ + indices_shape[indices_shape.NumDimensions() - 1];
|
||||
|
||||
//Output
|
||||
// Output
|
||||
auto shape_data = shape_tensor->Data<int64_t>();
|
||||
auto input_shape = TensorShape(shape_data, shape_tensor->SizeInBytes() / sizeof(shape_tensor->DataType()));
|
||||
|
||||
|
|
@ -59,8 +74,20 @@ Status GatherNDGrad<TIndex>::ComputeInternal(OpKernelContext* context) const {
|
|||
// TODO this memset can be expensive, a sparse tensor representation would help here
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_tensor->MutableDataRaw(), 0, output_tensor->SizeInBytes()));
|
||||
|
||||
auto status = CommonComputeKernel<TIndex>(batch_dims_, input_shape, update_tensor, output_tensor, indices_shape, indices_tensor, false);
|
||||
return status;
|
||||
// Compute
|
||||
int64_t num_slices;
|
||||
int64_t slice_size;
|
||||
IAllocatorUniquePtr<int64_t> input_slice_offsets_buffer;
|
||||
ORT_RETURN_IF_ERROR(PrepareCompute<TIndex>(batch_dims_, input_shape, indices_shape, indices_tensor,
|
||||
num_slices, slice_size, input_slice_offsets_buffer));
|
||||
|
||||
const void* const kernel_input_data = update_tensor->DataRaw();
|
||||
void* const kernel_output_data = output_tensor->MutableDataRaw();
|
||||
utils::MLTypeCallDispatcher<GatherNDGradComputeImpl, float, MLFloat16, double>
|
||||
t_disp(update_tensor->GetElementType());
|
||||
t_disp.Invoke(num_slices, slice_size, kernel_input_data, kernel_output_data, input_slice_offsets_buffer.get());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/tensor/gather_nd_impl.h"
|
||||
#include "orttraining/training_ops/cuda/tensor/gather_nd_grad_impl.h"
|
||||
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/atomic/common.cuh"
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
void GatherNDGradImpl(
|
||||
const size_t num_slices,
|
||||
const void* update_data,
|
||||
void* output_data,
|
||||
const size_t slice_size,
|
||||
const int64_t* input_slice_offsets_data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue