mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Implement ScatterND for CUDA EP (#6184)
This commit is contained in:
parent
945fae8f56
commit
fc27074bae
7 changed files with 307 additions and 33 deletions
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "scatter_nd.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -11,73 +12,87 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
11,
|
||||
12,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
ScatterND);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ScatterND,
|
||||
13,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
ScatterND);
|
||||
|
||||
template <typename Tind>
|
||||
Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
|
||||
const auto* input_tensor = context->Input<Tensor>(0);
|
||||
const auto* indice_tensor = context->Input<Tensor>(1);
|
||||
const auto* update_tensor = context->Input<Tensor>(2);
|
||||
ORT_ENFORCE(input_tensor != nullptr);
|
||||
ORT_ENFORCE(indice_tensor != nullptr);
|
||||
ORT_ENFORCE(update_tensor != nullptr);
|
||||
Status ScatterNDBase::ValidateShapes(const TensorShape& input_shape,
|
||||
const TensorShape& indice_shape,
|
||||
const TensorShape& update_shape) {
|
||||
auto input_rank = input_shape.NumDimensions();
|
||||
auto indice_rank = indice_shape.NumDimensions();
|
||||
auto update_rank = update_shape.NumDimensions();
|
||||
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
const auto& indice_shape = indice_tensor->Shape();
|
||||
const auto& update_shape = update_tensor->Shape();
|
||||
if (indice_shape.NumDimensions() == 0 || input_shape.NumDimensions() == 0) {
|
||||
if (input_rank == 0 || indice_rank == 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input tensor and indices tensor must has rank larger than 0. ",
|
||||
"input shape: ", input_shape, ", indices shape: ", indice_shape);
|
||||
}
|
||||
|
||||
auto indice_rank = indice_shape.NumDimensions();
|
||||
auto last_indice_dimension = indice_shape[indice_rank - 1];
|
||||
if (last_indice_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
|
||||
if (last_indice_dimension > static_cast<int64_t>(input_rank)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"last dimension of indices must not be larger than rank of input tensor");
|
||||
}
|
||||
|
||||
bool is_update_shape_invalid = [&]() {
|
||||
auto update_rank = update_shape.NumDimensions();
|
||||
auto input_rank = input_shape.NumDimensions();
|
||||
if (update_rank < indice_rank - 1) {
|
||||
// Validate rank of update tensor
|
||||
// Per spec, the rank of the update tensor should be:
|
||||
// (Rank of input tensor) + (Rank of indices tensor) -1 - last_indice_dimension
|
||||
if (update_rank != (input_rank + indice_rank - 1 - static_cast<int64_t>(last_indice_dimension))) {
|
||||
return true;
|
||||
}
|
||||
if (update_rank >= indice_rank - 1 &&
|
||||
indice_rank >= 1 &&
|
||||
(indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1))) {
|
||||
|
||||
// Validate shape of the update tensor
|
||||
// Part 1: The shape of the update tensor upto the indices rank - 1 (exclusive)
|
||||
// should match the shape of the indices tensor upto indices rank - 1 (exclusive)
|
||||
if (indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1)) {
|
||||
return true;
|
||||
}
|
||||
if ((static_cast<int64_t>(input_rank) > last_indice_dimension) &&
|
||||
(update_rank >= indice_rank - 1) &&
|
||||
(input_shape.Slice(last_indice_dimension) != update_shape.Slice(indice_rank - 1))) {
|
||||
|
||||
// Part 2: The shape of the update tensor after indices rank - 1 (inclusive)
|
||||
// should match the shape of the input tensor after `last_indice_dimension`
|
||||
if (input_shape.Slice(last_indice_dimension) != update_shape.Slice(indice_rank - 1)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}();
|
||||
|
||||
if (is_update_shape_invalid) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ",
|
||||
"updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
|
||||
const auto* input_tensor = context->Input<Tensor>(0);
|
||||
const auto* indice_tensor = context->Input<Tensor>(1);
|
||||
const auto* update_tensor = context->Input<Tensor>(2);
|
||||
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
const auto& indice_shape = indice_tensor->Shape();
|
||||
const auto& update_shape = update_tensor->Shape();
|
||||
|
||||
ORT_RETURN_IF_ERROR(ValidateShapes(input_shape, indice_shape, update_shape));
|
||||
|
||||
auto output_tensor = context->Output(0, input_shape);
|
||||
|
||||
const auto* src_base = input_tensor->DataRaw();
|
||||
auto* dst_base = output_tensor->MutableDataRaw();
|
||||
const bool is_string_type = input_tensor->IsDataTypeString();
|
||||
|
||||
auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1];
|
||||
|
||||
// Re-use input for output. If input/output Tensor* are the same, do not copy.
|
||||
if (src_base != dst_base) {
|
||||
if (is_string_type) {
|
||||
|
|
@ -92,15 +107,16 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
|
|||
|
||||
std::vector<int64_t> element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension
|
||||
|
||||
TensorPitches input_strides(input_shape);
|
||||
for (int64_t i = 0; i < last_indice_dimension; ++i) {
|
||||
element_counts[i] = input_shape.SizeFromDimension(i + 1);
|
||||
element_counts[i] = input_strides[i];
|
||||
}
|
||||
|
||||
int64_t err_indice = 0;
|
||||
p.element_bytes = input_tensor->DataType()->Size();
|
||||
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
|
||||
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
|
||||
auto indice_offset = static_cast<const Tind*>(indice_tensor->DataRaw());
|
||||
const int64_t* indice_offset = indice_tensor->template Data<int64_t>();
|
||||
auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy
|
||||
p.element_offsets.assign(offset_count, 0LL);
|
||||
|
||||
|
|
@ -124,12 +140,10 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
|
|||
return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
|
||||
}
|
||||
|
||||
template Status ScatterNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&) const;
|
||||
|
||||
Status ScatterND::Compute(OpKernelContext* context) const {
|
||||
Prepare p;
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute<int64_t>(context, p));
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
|
||||
return nullptr == p.input_str_base ? ScatterNumber(p, tp) : ScatterString(p, tp);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,11 @@ class ScatterNDBase {
|
|||
element_offsets(0) {}
|
||||
}; // struct Prepare
|
||||
|
||||
template <typename Tind>
|
||||
// Shared between the CPU and CUDA implementation
|
||||
static Status ValidateShapes(const TensorShape& input_shape,
|
||||
const TensorShape& indice_shape,
|
||||
const TensorShape& update_shape);
|
||||
|
||||
Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
|
||||
}; // class ScatterNDBase
|
||||
|
||||
|
|
|
|||
|
|
@ -759,6 +759,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t_float_int32_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t_MLFloat16_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t_MLFloat16_int32_t, OneHot);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, ScatterND);
|
||||
|
||||
// OpSet 12
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip);
|
||||
|
|
@ -988,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Identity);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
|
||||
|
||||
template <>
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
|
|
@ -1462,6 +1464,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t_float_int32_t, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t_MLFloat16_int64_t, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t_MLFloat16_int32_t, OneHot)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, ScatterND)>,
|
||||
|
||||
// OpSet 12
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip)>,
|
||||
|
|
@ -1691,6 +1694,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Identity)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
87
onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
Normal file
87
onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/tensor/scatter_nd.h"
|
||||
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Kernel registration: ScatterND in the ONNX domain, opset range 11-12, for the
// CUDA execution provider.
//  - "T" is constrained to DataTypeImpl::AllFixedSizeTensorTypes().
//  - MayInplace(0, 0): presumably lets output 0 share input 0's buffer; the
//    compute function checks for that aliasing before copying - confirm against
//    the KernelDefBuilder documentation.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    ScatterND,
    kOnnxDomain,
    11, 12,
    kCudaExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .MayInplace(0, 0),
    ScatterND);
|
||||
|
||||
// Kernel registration: ScatterND in the ONNX domain, opset 13, for the CUDA
// execution provider.
//  - "T" is constrained to DataTypeImpl::AllFixedSizeTensorTypes().
//  - MayInplace(0, 0): presumably lets output 0 share input 0's buffer; the
//    compute function checks for that aliasing before copying - confirm against
//    the KernelDefBuilder documentation.
ONNX_OPERATOR_KERNEL_EX(
    ScatterND,
    kOnnxDomain,
    13,
    kCudaExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .MayInplace(0, 0),
    ScatterND);
|
||||
|
||||
// Runs ScatterND for the CUDA execution provider.
//
// Reads three inputs from `context` (0 = data, 1 = indices, 2 = updates),
// allocates output 0 with the shape of input 0, copies the data into the output
// unless both already refer to the same buffer, and then launches the scatter
// through ScatterNDImpl.
//
// Returns:
//   - the status produced by ValidateShapes when the shapes are inconsistent,
//   - a FAIL status when the device-to-device copy cannot be issued,
//   - the status produced by ScatterNDImpl,
//   - Status::OK() otherwise (including the "no indices" early exit).
Status ScatterND::ComputeInternal(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  const auto* indices_tensor = context->Input<Tensor>(1);
  const auto* updates_tensor = context->Input<Tensor>(2);

  const auto& input_shape = input_tensor->Shape();
  const auto& indices_shape = indices_tensor->Shape();
  const auto& updates_shape = updates_tensor->Shape();

  // Validate input shapes. The status has to be propagated: the offsets and
  // copy sizes computed below are only meaningful for shapes that passed validation.
  ORT_RETURN_IF_ERROR(ValidateShapes(input_shape, indices_shape, updates_shape));

  auto* output_tensor = context->Output(0, input_shape);

  const void* input_data = input_tensor->DataRaw();
  void* output_data = output_tensor->MutableDataRaw();

  size_t element_size = input_tensor->DataType()->Size();

  // When the output does not alias the input, seed it with the input contents
  // so that every location not addressed by `indices` keeps its original value.
  if (input_data != output_data) {
    // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
    const cudaError_t copy_status = cudaMemcpyAsync(output_data, input_data,
                                                    element_size * input_shape.Size(),
                                                    cudaMemcpyDeviceToDevice);
    if (copy_status != cudaSuccess) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                             "ScatterND: cudaMemcpyAsync from input to output failed: ",
                             cudaGetErrorString(copy_status));
    }
  }

  // Bail out early: nothing to scatter.
  if (indices_shape.Size() == 0) {
    return Status::OK();
  }

  auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];

  // We need the element count (pitch) of each dimension and the input dim value of
  // each dimension for the range [0, last_index_dimension).
  // To avoid multiple GPU data transfers, both are packed into a single array:
  //   [0, last_index_dimension)                        -> pitches
  //   [last_index_dimension, 2 * last_index_dimension) -> dim values
  TensorPitches input_strides(input_shape);
  std::vector<int64_t> element_counts_and_input_dims(last_index_dimension * 2, 0LL);
  for (int64_t i = 0; i < last_index_dimension; ++i) {
    element_counts_and_input_dims[i] = input_strides[i];
    element_counts_and_input_dims[i + last_index_dimension] = input_shape[i];
  }

  CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this, element_counts_and_input_dims);
  // NOTE(review): the result of CopyToGpu() is not inspected here; if it reports
  // a Status it should be propagated as well - confirm against CudaAsyncBuffer.
  element_counts_and_input_dims_gpu.CopyToGpu();

  ORT_RETURN_IF_ERROR(ScatterNDImpl(
      output_data,
      element_size,
      indices_shape.Size() / static_cast<size_t>(last_index_dimension),  // number of index tuples
      indices_tensor->Data<int64_t>(),  // only int64_t is supported for indices as per the onnx spec
      last_index_dimension,
      element_counts_and_input_dims_gpu.GpuPtr(),
      updates_tensor->DataRaw(),
      input_shape.SizeFromDimension(last_index_dimension)));  // elements written per index tuple

  return Status::OK();
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
20
onnxruntime/core/providers/cuda/tensor/scatter_nd.h
Normal file
20
onnxruntime/core/providers/cuda/tensor/scatter_nd.h
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "core/providers/cpu/tensor/scatter_nd.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class ScatterND final : public CudaKernel, protected ScatterNDBase {
|
||||
public:
|
||||
explicit ScatterND(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
123
onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu
Normal file
123
onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/atomic/common.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Scatter kernel: each participating thread handles one index tuple.
//
// For its tuple (last_index_dimension consecutive int64 values in indices_data)
// the thread computes a base offset into output_data and then copies
// num_updates_elements consecutive elements from updates_data to that offset.
//
// element_counts_and_input_dims holds 2 * last_index_dimension values:
//   [0, last_index_dimension)                        multiplier applied to each index
//   [last_index_dimension, 2 * last_index_dimension) upper bound used to clamp each index
//
// CALCULATE_ELEMENTWISE_INDEX_OR_EXIT introduces `id`; presumably it is the flat
// thread index and the thread returns when id >= num_indices - confirm against
// the macro definition.
//
// Writes are plain stores (no atomics), so if two tuples resolve to the same
// offset the resulting value depends on which thread writes last.
template <typename T>
__global__ void _ScatterNDKernel(
    T* output_data,
    const size_t num_indices,
    const int64_t* indices_data,
    const int64_t last_index_dimension,
    const int64_t* element_counts_and_input_dims,
    const T* updates_data,
    const size_t num_updates_elements) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices);

  // Locate this thread's index tuple and the two halves of the packed array.
  const size_t tuple_begin = last_index_dimension * id;
  const int64_t* tuple = indices_data + tuple_begin;
  const int64_t* element_counts = element_counts_and_input_dims;
  const int64_t* input_dims = element_counts_and_input_dims + last_index_dimension;

  // Accumulate the base offset into the output data.
  int64_t data_offset = 0;
  for (int64_t dim = 0; dim < last_index_dimension; ++dim) {
    const int64_t raw_index = tuple[dim];
    const int64_t dim_value = input_dims[dim];

    // Clamp the index if out of range.
    // This would have been an error in the CPU kernel, but throwing in the CUDA EP
    // is hard. This is the approach taken by other frameworks for out of bound indices
    // in their corresponding GPU backends as well.
    const int64_t index = (raw_index < 0)
                              ? 0
                              : ((raw_index >= dim_value) ? dim_value - 1 : raw_index);

    data_offset += index * element_counts[dim];
  }

  // Copy the slice of updates belonging to this tuple into the output.
  const T* src = updates_data + num_updates_elements * id;
  T* dst = output_data + data_offset;
  for (size_t remaining = num_updates_elements; remaining > 0; --remaining) {
    *dst++ = *src++;
  }
}
|
||||
|
||||
// Launches _ScatterNDKernel instantiated for element type T.
// The untyped data pointers are reinterpreted as T; the launch uses no dynamic
// shared memory and no explicit stream argument, exactly like the callers' original launches.
template <typename T>
static void LaunchScatterNDKernel(
    const int blocks_per_grid,
    void* output_data,
    const size_t num_indices,
    const int64_t* indices_data,
    const int64_t last_index_dimension,
    const int64_t* element_counts_and_input_dims,
    const void* updates_data,
    const size_t num_updates_elements) {
  _ScatterNDKernel<T><<<blocks_per_grid, GridDim::maxThreadsPerBlock, 0>>>(
      reinterpret_cast<T*>(output_data),
      num_indices,
      indices_data,
      last_index_dimension,
      element_counts_and_input_dims,
      reinterpret_cast<const T*>(updates_data),
      num_updates_elements);
}

// Host-side entry point that scatters `updates_data` into `output_data`.
//
// Parameters:
//   output_data                    destination buffer passed to the kernel
//   element_size                   size in bytes of one element; 1, 2, 4 and 8 are handled
//   num_indices                    number of index tuples (one kernel thread per tuple)
//   indices_data                   num_indices * last_index_dimension int64 indices
//   last_index_dimension           number of indices per tuple
//   element_counts_and_input_dims  2 * last_index_dimension values consumed by the kernel
//   updates_data                   source buffer, num_updates_elements elements per tuple
//   num_updates_elements           number of elements copied per tuple
//
// Elements are moved as same-width integers, so the element type itself does
// not matter, only its size.
//
// Returns Status::OK() after issuing the launch (or when num_indices is 0), and
// a FAIL status for an unsupported element_size.
Status ScatterNDImpl(
    void* output_data,
    const size_t element_size,
    const size_t num_indices,
    const int64_t* indices_data,
    const int64_t last_index_dimension,
    const int64_t* element_counts_and_input_dims,
    const void* updates_data,
    const size_t num_updates_elements) {
  if (num_indices == 0)
    return Status::OK();

  // Parallelize on number of indices.
  // Integer ceil-division: going through float would round num_indices once it
  // exceeds 2^24 and could yield one block too few, leaving trailing tuples unprocessed.
  const size_t threads_per_block = static_cast<size_t>(GridDim::maxThreadsPerBlock);
  const int blocksPerGrid = static_cast<int>((num_indices + threads_per_block - 1) / threads_per_block);

  switch (element_size) {
    case sizeof(int8_t):
      LaunchScatterNDKernel<int8_t>(blocksPerGrid, output_data, num_indices, indices_data,
                                    last_index_dimension, element_counts_and_input_dims,
                                    updates_data, num_updates_elements);
      break;

    case sizeof(int16_t):
      LaunchScatterNDKernel<int16_t>(blocksPerGrid, output_data, num_indices, indices_data,
                                     last_index_dimension, element_counts_and_input_dims,
                                     updates_data, num_updates_elements);
      break;

    case sizeof(int32_t):
      LaunchScatterNDKernel<int32_t>(blocksPerGrid, output_data, num_indices, indices_data,
                                     last_index_dimension, element_counts_and_input_dims,
                                     updates_data, num_updates_elements);
      break;

    case sizeof(int64_t):
      LaunchScatterNDKernel<int64_t>(blocksPerGrid, output_data, num_indices, indices_data,
                                     last_index_dimension, element_counts_and_input_dims,
                                     updates_data, num_updates_elements);
      break;

    default:
      // Shouldn't hit this
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for ScatterND operator");
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
22
onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h
Normal file
22
onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Scatters slices of `updates_data` into `output_data` on the GPU.
//
//   output_data                    destination buffer
//   element_size                   size in bytes of a single element
//   num_indices                    number of index tuples to process
//   indices_data                   int64 indices, last_index_dimension per tuple
//   last_index_dimension           number of indices in each tuple
//   element_counts_and_input_dims  packed per-dimension values used to turn an
//                                  index tuple into an output offset
//   updates_data                   source buffer holding the update slices
//   num_updates_elements           number of elements written per index tuple
//
// Returns a Status describing whether the operation could be issued.
Status ScatterNDImpl(void* output_data,
                     const size_t element_size,
                     const size_t num_indices,
                     const int64_t* indices_data,
                     const int64_t last_index_dimension,
                     const int64_t* element_counts_and_input_dims,
                     const void* updates_data,
                     const size_t num_updates_elements);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue