mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Add implementation for ScatterND (#19540)
### Description onnxruntime switches to CPU for ScatterND after opset 13. This extends the implementation of higher opsets.
This commit is contained in:
parent
14fcf0a52d
commit
80213a9e66
15 changed files with 869 additions and 42 deletions
|
|
@ -774,7 +774,9 @@ Do not modify directly.*
|
|||
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|SequenceAt|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -1157,7 +1157,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
|
||||
|
|
@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
|
||||
|
||||
// Opset 17
|
||||
|
|
@ -1312,6 +1313,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
|
||||
|
|
@ -2071,7 +2073,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
|
||||
|
|
@ -2202,6 +2204,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,
|
||||
|
||||
// Opset 17
|
||||
|
|
@ -2225,6 +2228,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
|
|||
} else if (reduction_ == "max") {
|
||||
args.operation = GatherScatterElementsArgs::Operation::MAX;
|
||||
} else {
|
||||
ORT_THROW("Unsupported reduction type");
|
||||
ORT_THROW("Unsupported reduction type for ScatterElements.");
|
||||
}
|
||||
|
||||
// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "core/providers/cuda/tensor/scatter_nd.h"
|
||||
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
|
||||
#include "core/providers/cuda/tensor/scatter_nd_common.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
|
||||
|
|
@ -16,18 +17,61 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
|
|||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.MayInplace(0, 0),
|
||||
ScatterND);
|
||||
ScatterNDDisjointAndNoReduction);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
|
||||
kOnnxDomain,
|
||||
13, 15,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.MayInplace(0, 0),
|
||||
ScatterNDWithAtomicReduction);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
|
||||
kOnnxDomain,
|
||||
16, 17,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.MayInplace(0, 0),
|
||||
ScatterNDWithAtomicReduction);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(ScatterND,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
18,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.MayInplace(0, 0),
|
||||
ScatterND);
|
||||
ScatterNDWithAtomicReduction);
|
||||
|
||||
Status ScatterND::ComputeInternal(OpKernelContext* context) const {
|
||||
static Status InitiliazeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape,
|
||||
ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& element_counts_and_input_dims_gpu,
|
||||
onnxruntime::OpKernelContext* context) {
|
||||
TensorPitches input_strides(input_shape);
|
||||
|
||||
if (last_index_dimension < 6) {
|
||||
element_counts_and_input_dims.gpu_ptr = nullptr;
|
||||
for (int64_t i = 0; i < last_index_dimension; ++i) {
|
||||
element_counts_and_input_dims.stack_ptr[i] = input_strides[i];
|
||||
element_counts_and_input_dims.stack_ptr[i + last_index_dimension] = input_shape[i];
|
||||
}
|
||||
} else {
|
||||
element_counts_and_input_dims_gpu.AllocCpuPtr(last_index_dimension * 2);
|
||||
memset(element_counts_and_input_dims_gpu.CpuPtr(), 0, sizeof(int64_t) * last_index_dimension * 2);
|
||||
for (int64_t i = 0; i < last_index_dimension; ++i) {
|
||||
element_counts_and_input_dims_gpu.CpuPtr()[i] = input_strides[i];
|
||||
element_counts_and_input_dims_gpu.CpuPtr()[i + last_index_dimension] = input_shape[i];
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
|
||||
element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context) const {
|
||||
const auto* input_tensor = context->Input<Tensor>(0);
|
||||
const auto* indices_tensor = context->Input<Tensor>(1);
|
||||
const auto* updates_tensor = context->Input<Tensor>(2);
|
||||
|
|
@ -44,8 +88,6 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
|
|||
const void* input_data = input_tensor->DataRaw();
|
||||
void* output_data = output_tensor->MutableDataRaw();
|
||||
|
||||
size_t element_size = input_tensor->DataType()->Size();
|
||||
|
||||
if (input_data != output_data) {
|
||||
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
|
||||
CUDA_RETURN_IF_ERROR(
|
||||
|
|
@ -58,18 +100,17 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
|
||||
size_t element_size = input_tensor->DataType()->Size();
|
||||
|
||||
// We need element counts for each dimension and the input dim value for each dimension
|
||||
// for the range [0, last_index_dimension).
|
||||
// To avoid multiple GPU data transfers, we combine this into one array and send it through
|
||||
TensorPitches input_strides(input_shape);
|
||||
std::vector<int64_t> element_counts_and_input_dims(last_index_dimension * 2, 0LL);
|
||||
for (int64_t i = 0; i < last_index_dimension; ++i) {
|
||||
element_counts_and_input_dims[i] = input_strides[i];
|
||||
element_counts_and_input_dims[i + last_index_dimension] = input_shape[i];
|
||||
}
|
||||
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this, element_counts_and_input_dims);
|
||||
ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
|
||||
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
|
||||
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
|
||||
ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
|
||||
element_counts_and_input_dims,
|
||||
element_counts_and_input_dims_gpu,
|
||||
context));
|
||||
|
||||
ORT_RETURN_IF_ERROR(ScatterNDImpl(
|
||||
Stream(context),
|
||||
|
|
@ -78,12 +119,89 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
|
|||
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
|
||||
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims_gpu.GpuPtr(),
|
||||
element_counts_and_input_dims,
|
||||
updates_tensor->DataRaw(),
|
||||
input_shape.SizeFromDimension(last_index_dimension)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) const {
|
||||
const auto* input_tensor = context->Input<Tensor>(0);
|
||||
const auto* indices_tensor = context->Input<Tensor>(1);
|
||||
const auto* updates_tensor = context->Input<Tensor>(2);
|
||||
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
const auto& indices_shape = indices_tensor->Shape();
|
||||
const auto& updates_shape = updates_tensor->Shape();
|
||||
|
||||
// Validate input shapes
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape));
|
||||
|
||||
auto* output_tensor = context->Output(0, input_shape);
|
||||
|
||||
const void* input_data = input_tensor->DataRaw();
|
||||
void* output_data = output_tensor->MutableDataRaw();
|
||||
|
||||
if (input_data != output_data) {
|
||||
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will
|
||||
// be faster than invoking cudaMemcpy ?
|
||||
CUDA_RETURN_IF_ERROR(
|
||||
cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(),
|
||||
cudaMemcpyDeviceToDevice, Stream(context)));
|
||||
}
|
||||
|
||||
// Bail out early
|
||||
if (indices_shape.Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
|
||||
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
|
||||
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
|
||||
ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
|
||||
element_counts_and_input_dims,
|
||||
element_counts_and_input_dims_gpu,
|
||||
context));
|
||||
|
||||
switch (reduction_) {
|
||||
case ScatterNDReduction::None: {
|
||||
size_t element_size = input_tensor->DataType()->Size();
|
||||
ORT_RETURN_IF_ERROR(ScatterNDImpl(
|
||||
Stream(context),
|
||||
output_data,
|
||||
element_size,
|
||||
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
|
||||
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
updates_tensor->DataRaw(),
|
||||
input_shape.SizeFromDimension(last_index_dimension)));
|
||||
} break;
|
||||
case ScatterNDReduction::Add:
|
||||
case ScatterNDReduction::Min:
|
||||
case ScatterNDReduction::Max:
|
||||
case ScatterNDReduction::Mul: {
|
||||
auto element_type = input_tensor->DataType()->AsPrimitiveDataType()->GetDataType();
|
||||
ORT_RETURN_IF_ERROR(ScatterNDImplReduction(
|
||||
Stream(context),
|
||||
output_data,
|
||||
element_type,
|
||||
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
|
||||
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
updates_tensor->DataRaw(),
|
||||
input_shape.SizeFromDimension(last_index_dimension),
|
||||
reduction_));
|
||||
} break;
|
||||
default:
|
||||
ORT_THROW("ScatterND not supported for other reduction than Add, None.");
|
||||
break;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,18 +3,63 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "core/providers/cuda/tensor/scatter_nd_kind.h"
|
||||
#include "core/providers/cpu/tensor/scatter_nd.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class ScatterND final : public CudaKernel {
|
||||
/**
|
||||
* This implementation assumes there is common indices and
|
||||
* reduction is not needed. The code does not check that condition.
|
||||
* However in that case, the same output element could be accessed
|
||||
* from different threads at the same time and the final value
|
||||
* is unlikely to be correct.
|
||||
*/
|
||||
class ScatterNDDisjointAndNoReduction final : public CudaKernel {
|
||||
public:
|
||||
explicit ScatterND(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
explicit ScatterNDDisjointAndNoReduction(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
/**
|
||||
* This is an implementation derived from the first one.
|
||||
* It does atomic operation to handle conflicts.
|
||||
* The result is unlikely to be correct if the reduction is none
|
||||
* as there is no guarantee that the final value will be the one
|
||||
* corresponding to the highest visited index.
|
||||
* TODO: change the implementation of avoid conflicts.
|
||||
*/
|
||||
class ScatterNDWithAtomicReduction final : public CudaKernel {
|
||||
public:
|
||||
explicit ScatterNDWithAtomicReduction(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
std::string reduction;
|
||||
|
||||
if (info.GetAttr<std::string>("reduction", &reduction).IsOK()) {
|
||||
if (reduction == "add") {
|
||||
reduction_ = ScatterNDReduction::Add;
|
||||
} else if (reduction == "mul") {
|
||||
reduction_ = ScatterNDReduction::Mul;
|
||||
} else if (reduction == "min") {
|
||||
reduction_ = ScatterNDReduction::Min;
|
||||
} else if (reduction == "max") {
|
||||
reduction_ = ScatterNDReduction::Max;
|
||||
} else if (reduction == "none") {
|
||||
LOGS_DEFAULT(WARNING) << "ScatterND with reduction=='none' only guarantees "
|
||||
<< "to be correct if indices are not duplicated.";
|
||||
} else {
|
||||
ORT_THROW("Reduction '", reduction, "' is not supported on CUDA and opset >= 13.");
|
||||
}
|
||||
}
|
||||
}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
ScatterNDReduction reduction_{ScatterNDReduction::None};
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
15
onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h
Normal file
15
onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
struct ElementCountsAndInputDimsSpanOrGpu {
|
||||
int64_t stack_ptr[12];
|
||||
int64_t* gpu_ptr;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -14,7 +14,7 @@ __global__ void _ScatterNDKernel(
|
|||
const size_t num_indices,
|
||||
const int64_t* indices_data,
|
||||
const int64_t last_index_dimension,
|
||||
const int64_t* element_counts_and_input_dims,
|
||||
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims,
|
||||
const T* updates_data,
|
||||
const size_t num_updates_elements) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices);
|
||||
|
|
@ -27,8 +27,12 @@ __global__ void _ScatterNDKernel(
|
|||
for (size_t i = indices_start; i < indices_end; ++i) {
|
||||
int64_t index = indices_data[i];
|
||||
|
||||
int64_t element_count_dim = element_counts_and_input_dims[i - indices_start];
|
||||
int64_t dim_value = element_counts_and_input_dims[i - indices_start + last_index_dimension];
|
||||
int64_t element_count_dim = element_counts_and_input_dims.gpu_ptr == nullptr
|
||||
? element_counts_and_input_dims.stack_ptr[i - indices_start]
|
||||
: element_counts_and_input_dims.gpu_ptr[i - indices_start];
|
||||
int64_t dim_value = element_counts_and_input_dims.gpu_ptr == nullptr
|
||||
? element_counts_and_input_dims.stack_ptr[i - indices_start + last_index_dimension]
|
||||
: element_counts_and_input_dims.gpu_ptr[i - indices_start + last_index_dimension];
|
||||
|
||||
// Clamp the index if out of range
|
||||
// This would have been an error in the CPU kernel, but throwing in the CUDA EP
|
||||
|
|
@ -66,7 +70,7 @@ Status ScatterNDImpl(
|
|||
const size_t num_indices,
|
||||
const int64_t* indices_data,
|
||||
const int64_t last_index_dimension,
|
||||
const int64_t* element_counts_and_input_dims,
|
||||
const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
|
||||
const void* updates_data,
|
||||
const size_t num_updates_elements) {
|
||||
if (num_indices == 0)
|
||||
|
|
@ -128,5 +132,197 @@ Status ScatterNDImpl(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct FuncAdd {
|
||||
__device__ __inline__ void operator()(T* start_addr, T value) const {
|
||||
atomic_add(start_addr, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct FuncMul {
|
||||
__device__ __inline__ void operator()(T* start_addr, T value) const {
|
||||
atomic_mul(start_addr, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct FuncMax {
|
||||
__device__ __inline__ void operator()(T* start_addr, T value) const {
|
||||
atomic_max(start_addr, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct FuncMin {
|
||||
__device__ __inline__ void operator()(T* start_addr, T value) const {
|
||||
atomic_min(start_addr, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename TFunc>
|
||||
__global__ void _ScatterNDKernelReduction(
|
||||
T* output_data,
|
||||
const size_t num_indices,
|
||||
const int64_t* indices_data,
|
||||
const int64_t last_index_dimension,
|
||||
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims,
|
||||
const T* updates_data,
|
||||
const size_t num_updates_elements,
|
||||
const TFunc func) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices);
|
||||
|
||||
// Compute the base offset into the output data
|
||||
int64_t data_offset = 0;
|
||||
|
||||
size_t indices_start = last_index_dimension * id;
|
||||
size_t indices_end = indices_start + last_index_dimension;
|
||||
for (size_t i = indices_start; i < indices_end; ++i) {
|
||||
int64_t index = indices_data[i];
|
||||
|
||||
int64_t element_count_dim = element_counts_and_input_dims.gpu_ptr == nullptr
|
||||
? element_counts_and_input_dims.stack_ptr[i - indices_start]
|
||||
: element_counts_and_input_dims.gpu_ptr[i - indices_start];
|
||||
int64_t dim_value = element_counts_and_input_dims.gpu_ptr == nullptr
|
||||
? element_counts_and_input_dims.stack_ptr[i - indices_start + last_index_dimension]
|
||||
: element_counts_and_input_dims.gpu_ptr[i - indices_start + last_index_dimension];
|
||||
|
||||
// Clamp the index if out of range
|
||||
// This would have been an error in the CPU kernel, but throwing in the CUDA EP
|
||||
// is hard. This is the approach taken by other frameworks for out of bound indices
|
||||
// in their corresponding GPU backends as well.
|
||||
// index >= -dim_value && index < dim_value
|
||||
|
||||
if (index >= 0) {
|
||||
if (index >= dim_value) {
|
||||
index = dim_value - 1;
|
||||
}
|
||||
} else {
|
||||
if (index < -dim_value) {
|
||||
index = 0;
|
||||
} else {
|
||||
index += dim_value;
|
||||
}
|
||||
}
|
||||
|
||||
data_offset += (index * element_count_dim);
|
||||
}
|
||||
|
||||
const T* updates_data_base = updates_data + num_updates_elements * id;
|
||||
T* output_data_base = output_data + data_offset;
|
||||
|
||||
for (size_t i = 0; i < num_updates_elements; ++i) {
|
||||
func(output_data_base + i, updates_data_base[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status _ScatterNDType(
|
||||
cudaStream_t stream,
|
||||
T* output_data,
|
||||
const size_t num_indices,
|
||||
const int64_t* indices_data,
|
||||
const int64_t last_index_dimension,
|
||||
const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
|
||||
const T* updates_data,
|
||||
const size_t num_updates_elements,
|
||||
ScatterNDReduction reduction) {
|
||||
// Parallelize on number of indices
|
||||
int blocksPerGrid = static_cast<int>(ceil(static_cast<float>(num_indices) / GridDim::maxThreadsPerBlock));
|
||||
|
||||
switch (reduction) {
|
||||
case ScatterNDReduction::Add:
|
||||
_ScatterNDKernelReduction<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
output_data,
|
||||
num_indices,
|
||||
indices_data,
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
updates_data,
|
||||
num_updates_elements,
|
||||
FuncAdd<T>());
|
||||
break;
|
||||
case ScatterNDReduction::Mul:
|
||||
_ScatterNDKernelReduction<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
output_data,
|
||||
num_indices,
|
||||
indices_data,
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
updates_data,
|
||||
num_updates_elements,
|
||||
FuncMul<T>());
|
||||
break;
|
||||
case ScatterNDReduction::Min:
|
||||
_ScatterNDKernelReduction<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
output_data,
|
||||
num_indices,
|
||||
indices_data,
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
updates_data,
|
||||
num_updates_elements,
|
||||
FuncMin<T>());
|
||||
break;
|
||||
case ScatterNDReduction::Max:
|
||||
_ScatterNDKernelReduction<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
output_data,
|
||||
num_indices,
|
||||
indices_data,
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
updates_data,
|
||||
num_updates_elements,
|
||||
FuncMax<T>());
|
||||
break;
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Reduction ", static_cast<int>(reduction), " not implemented for ScatterND operator.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ScatterNDImplReduction(
|
||||
cudaStream_t stream,
|
||||
void* output_data,
|
||||
const int32_t element_type,
|
||||
const size_t num_indices,
|
||||
const int64_t* indices_data,
|
||||
const int64_t last_index_dimension,
|
||||
const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
|
||||
const void* updates_data,
|
||||
const size_t num_updates_elements,
|
||||
ScatterNDReduction reduction) {
|
||||
if (num_indices == 0)
|
||||
return Status::OK();
|
||||
|
||||
switch (element_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
return _ScatterNDType<float>(
|
||||
stream,
|
||||
reinterpret_cast<float*>(output_data),
|
||||
num_indices,
|
||||
indices_data,
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
reinterpret_cast<const float*>(updates_data),
|
||||
num_updates_elements,
|
||||
reduction);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
return _ScatterNDType<half>(
|
||||
stream,
|
||||
reinterpret_cast<half*>(output_data),
|
||||
num_indices,
|
||||
indices_data,
|
||||
last_index_dimension,
|
||||
element_counts_and_input_dims,
|
||||
reinterpret_cast<const half*>(updates_data),
|
||||
num_updates_elements,
|
||||
reduction);
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "element_type ", static_cast<int>(element_type), " not implemented for ScatterND operator.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/providers/cuda/tensor/scatter_nd_kind.h"
|
||||
#include "core/providers/cuda/tensor/scatter_nd_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
@ -15,9 +17,21 @@ Status ScatterNDImpl(
|
|||
const size_t num_indices,
|
||||
const int64_t* indices_data,
|
||||
const int64_t last_index_dimension,
|
||||
const int64_t* element_counts_and_input_dims,
|
||||
const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
|
||||
const void* updates_data,
|
||||
const size_t num_updates_elements);
|
||||
|
||||
Status ScatterNDImplReduction(
|
||||
cudaStream_t stream,
|
||||
void* output_data,
|
||||
const int32_t element_type,
|
||||
const size_t num_indices,
|
||||
const int64_t* indices_data,
|
||||
const int64_t last_index_dimension,
|
||||
const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
|
||||
const void* updates_data,
|
||||
const size_t num_updates_elements,
|
||||
ScatterNDReduction reduction);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
18
onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h
Normal file
18
onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
enum class ScatterNDReduction : int {
|
||||
None = 0,
|
||||
Add = 1,
|
||||
Mul = 2,
|
||||
Min = 3,
|
||||
Max = 4,
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1161,7 +1161,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
|
||||
|
|
@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
|
||||
|
||||
// Opset 17
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
|
||||
|
|
@ -1308,6 +1309,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize);
|
||||
|
|
@ -2115,7 +2117,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
|
||||
|
|
@ -2249,6 +2251,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
|
||||
|
||||
// Opset 17
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
|
||||
|
|
@ -2262,6 +2265,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18,
|
||||
float, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18,
|
||||
|
|
|
|||
|
|
@ -180,5 +180,60 @@ TEST(ScatterNDOpTest, ScatterND_batched_3tensor_int64) {
|
|||
test3.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterNDOpTest, ScatterND_18_add) {
|
||||
OpTester test1("ScatterND", 18);
|
||||
test1.AddAttribute("reduction", "add");
|
||||
test1.AddInput<float>("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
|
||||
test1.AddInput<int64_t>("indices", {3, 1}, {0, 1, 0});
|
||||
// The linter complains if the line is split into multiple lines.
|
||||
test1.AddInput<float>("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
|
||||
test1.AddOutput<float>("output", {2, 2, 3}, {8194.1f, 16388.1f, 32776.10f, 65552.10f, 131104.1f, 262208.1f, 128.1f, 256.1f, 512.1f, 1024.1f, 2048.1f, 4096.1f});
|
||||
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(ScatterNDOpTest, ScatterND_18_mul) {
|
||||
OpTester test1("ScatterND", 18);
|
||||
test1.AddAttribute("reduction", "mul");
|
||||
test1.AddInput<float>("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
|
||||
test1.AddInput<int64_t>("indices", {3, 1}, {0, 1, 0});
|
||||
// The linter complains if the line is split into multiple lines.
|
||||
test1.AddInput<float>("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
|
||||
test1.AddOutput<float>("output", {2, 2, 3}, {1638.4f, 6553.6f, 26214.4f, 104857.6f, 419430.4f, 1677721.625f, 12.8f, 25.6f, 51.2f, 102.4f, 204.8f, 409.6f});
|
||||
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(ScatterNDOpTest, ScatterND_18_mul_long_shape) {
|
||||
OpTester test1("ScatterND", 18);
|
||||
test1.AddAttribute("reduction", "mul");
|
||||
test1.AddInput<float>("data", {2, 2, 3, 1, 1, 1, 1}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
|
||||
test1.AddInput<int64_t>("indices", {3, 1}, {0, 1, 0});
|
||||
// The linter complains if the line is split into multiple lines.
|
||||
test1.AddInput<float>("updates", {3, 2, 3, 1, 1, 1, 1}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
|
||||
test1.AddOutput<float>("output", {2, 2, 3, 1, 1, 1, 1}, {1638.4f, 6553.6f, 26214.4f, 104857.6f, 419430.4f, 1677721.625f, 12.8f, 25.6f, 51.2f, 102.4f, 204.8f, 409.6f});
|
||||
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(ScatterNDOpTest, ScatterND_18_min) {
|
||||
OpTester test1("ScatterND", 18);
|
||||
test1.AddAttribute("reduction", "min");
|
||||
test1.AddInput<float>("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
|
||||
test1.AddInput<int64_t>("indices", {3, 1}, {0, 1, 0});
|
||||
// The linter complains if the line is split into multiple lines.
|
||||
test1.AddInput<float>("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
|
||||
test1.AddOutput<float>("output", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
|
||||
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(ScatterNDOpTest, ScatterND_18_max) {
|
||||
OpTester test1("ScatterND", 18);
|
||||
test1.AddAttribute("reduction", "max");
|
||||
test1.AddInput<float>("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
|
||||
test1.AddInput<int64_t>("indices", {3, 1}, {0, 1, 0});
|
||||
// The linter complains if the line is split into multiple lines.
|
||||
test1.AddInput<float>("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
|
||||
test1.AddOutput<float>("output", {2, 2, 3}, {8192.0, 16384.0, 32768.0, 65536.0, 131072.0, 262144.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0});
|
||||
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -308,14 +308,31 @@ TEST(ScatterElements, AddReduction) {
|
|||
test.AddAttribute<int64_t>("axis", 0);
|
||||
test.AddAttribute<std::string>("reduction", "add");
|
||||
|
||||
test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
|
||||
test.AddInput<int64_t>("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
test.AddInput<float>("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f});
|
||||
test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)});
|
||||
test.AddInput<float>("data", {3, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
|
||||
test.AddInput<int64_t>("indices", {2, 3}, {1, 0, 2, 0, 2, 1});
|
||||
test.AddInput<float>("updates", {2, 3}, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
|
||||
test.AddOutput<float>("y", {3, 3}, {3.0f, 1.1f, 0.0f, 1.0f, 0.0f, 2.2f, 0.0f, 2.1f, 1.2f});
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
|
||||
#if defined(CUDA_VERSION)
|
||||
// Operation on float16 (MLFloat16) is not implemented on CPU.
|
||||
TEST(ScatterElements, AddReduction_MLFloat16) {
|
||||
OpTester test("ScatterElements", 18);
|
||||
test.AddAttribute<int64_t>("axis", 0);
|
||||
test.AddAttribute<std::string>("reduction", "add");
|
||||
|
||||
test.AddInput<MLFloat16>("data", {2, 3}, ToFloat16(std::vector<float>({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f})));
|
||||
test.AddInput<int64_t>("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
test.AddInput<MLFloat16>("updates", {4, 3}, ToFloat16(std::vector<float>({1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f})));
|
||||
test.AddOutput<MLFloat16>("y", {2, 3}, ToFloat16(std::vector<float>({-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)})));
|
||||
|
||||
// exclude CPU Execution Provider as MLFloat16 is not supported in CPU
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(ScatterElements, AddReductionAxis1) {
|
||||
OpTester test("ScatterElements", 18);
|
||||
test.AddAttribute<int64_t>("axis", 1);
|
||||
|
|
|
|||
|
|
@ -89,14 +89,16 @@ def apply_filters(filters, category):
|
|||
|
||||
def load_jsonc(basename: str):
|
||||
"""Returns a deserialized object from the JSONC file in testdata/<basename>."""
|
||||
filename = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"testdata",
|
||||
basename,
|
||||
)
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError(f"File not found {filename!r}.")
|
||||
filenames = [
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), "testdata", basename),
|
||||
os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..", "test", "testdata", basename)),
|
||||
]
|
||||
|
||||
filtered = [f for f in filenames if os.path.exists(f)]
|
||||
if not filtered:
|
||||
raise FileNotFoundError(f"No file found in {filenames!r}.")
|
||||
|
||||
filename = filtered[0]
|
||||
with open(filename, encoding="utf-8") as f: # pylint: disable=invalid-name
|
||||
lines = f.readlines()
|
||||
lines = [x.split("//")[0] for x in lines]
|
||||
|
|
|
|||
329
onnxruntime/test/python/onnxruntime_test_scatternd.py
Normal file
329
onnxruntime/test/python/onnxruntime_test_scatternd.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
import itertools
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import onnx.helper as oh
|
||||
from onnx import TensorProto, load
|
||||
from onnx.numpy_helper import from_array
|
||||
from onnx.reference import ReferenceEvaluator
|
||||
|
||||
import onnxruntime
|
||||
|
||||
|
||||
def has_cuda():
|
||||
available_providers = [provider for provider in onnxruntime.get_available_providers()]
|
||||
return "CUDAExecutionProvider" in available_providers
|
||||
|
||||
|
||||
def ignore_warnings(warns: typing.List[Warning]) -> typing.Callable:
|
||||
def wrapper(fct):
|
||||
if warns is None:
|
||||
raise AssertionError(f"warns cannot be None for '{fct}'.")
|
||||
|
||||
def call_f(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", warns)
|
||||
return fct(self)
|
||||
|
||||
return call_f
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class TestScatterPerProvider(unittest.TestCase):
|
||||
def assert_exists(self, filename: str):
|
||||
assert os.path.exists(filename), f"Unable to find {filename!r}."
|
||||
|
||||
def common_scatter(self, opset, providers, dtype, reduction, expected_names):
|
||||
from onnxruntime import InferenceSession, SessionOptions
|
||||
|
||||
op_type = "ScatterElements" if "ScatterElements" in expected_names else "ScatterND"
|
||||
ndim = 2 if op_type == "ScatterElements" else 3
|
||||
|
||||
assert dtype in (np.float16, np.float32)
|
||||
itype = TensorProto.FLOAT if dtype == np.float32 else TensorProto.FLOAT16
|
||||
model = oh.make_model(
|
||||
oh.make_graph(
|
||||
[
|
||||
oh.make_node("CastLike", ["X", "I"], ["data"]),
|
||||
oh.make_node(
|
||||
op_type,
|
||||
inputs=["data", "indices", "updates"],
|
||||
outputs=["sy"],
|
||||
# axis=0,
|
||||
reduction=reduction,
|
||||
),
|
||||
oh.make_node("Sub", ["sy", "I"], ["Y"]),
|
||||
],
|
||||
"name",
|
||||
[
|
||||
oh.make_tensor_value_info("X", TensorProto.FLOAT, [None] * ndim),
|
||||
oh.make_tensor_value_info("indices", TensorProto.INT64, [None, None]),
|
||||
oh.make_tensor_value_info("updates", itype, [None] * ndim),
|
||||
],
|
||||
[oh.make_tensor_value_info("Y", itype, [None] * ndim)],
|
||||
[from_array(np.array([0], dtype=dtype), name="I")],
|
||||
),
|
||||
opset_imports=[oh.make_opsetid("", opset)],
|
||||
ir_version=8 if opset <= 18 else 9,
|
||||
)
|
||||
|
||||
if not os.path.exists("temp_dump"):
|
||||
os.mkdir("temp_dump")
|
||||
for name in os.listdir("temp_dump"):
|
||||
os.remove(os.path.join("temp_dump", name))
|
||||
|
||||
filename = f"temp_dump/{op_type}_{providers[0]}_{itype}.onnx"
|
||||
opts = SessionOptions()
|
||||
opts.optimized_model_filepath = filename
|
||||
sess = InferenceSession(model.SerializeToString(), opts, providers=providers)
|
||||
self.assertTrue(sess is not None)
|
||||
self.assert_exists(filename)
|
||||
onx = load(filename)
|
||||
names = [n.op_type for n in onx.graph.node]
|
||||
self.assertEqual(expected_names, names)
|
||||
|
||||
sonx = str(onx).replace(" ", "").replace("\n", "|")
|
||||
sexp = 'op_type:"Cast"|attribute{|name:"to"|type:INT|i:%d|}' % itype
|
||||
sexp2 = 'op_type:"Cast"|attribute{|name:"to"|i:%d|type:INT|}' % itype
|
||||
assert sexp in sonx or sexp2 in sonx, f"Unable to find a substring in {sonx!r}"
|
||||
if providers == ["CPUExecutionProvider"]:
|
||||
return
|
||||
|
||||
if op_type == "ScatterElements":
|
||||
data = np.zeros((3, 3), dtype=np.float32)
|
||||
data[0, 0] = 1
|
||||
indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64)
|
||||
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=dtype)
|
||||
else:
|
||||
data = np.array(
|
||||
[
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
|
||||
[[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
|
||||
[[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
indices = np.array([[0], [2]], dtype=np.int64)
|
||||
updates = np.array(
|
||||
[
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
|
||||
],
|
||||
dtype=dtype,
|
||||
)
|
||||
opts = SessionOptions()
|
||||
opts.enable_profiling = True
|
||||
opts.optimized_model_filepath = filename
|
||||
sess = InferenceSession(model.SerializeToString(), opts, providers=providers)
|
||||
got = sess.run(None, {"X": data, "indices": indices, "updates": updates})[0]
|
||||
self.assertEqual(got.dtype, updates.dtype)
|
||||
prof = sess.end_profiling()
|
||||
|
||||
with open(prof, "r") as f: # noqa: UP015
|
||||
content = f.read()
|
||||
js = json.loads(content)
|
||||
|
||||
exe_providers = []
|
||||
suffixes = ["_kernel_time", "_fence_before", "_fence_after"]
|
||||
rows = []
|
||||
for row in js:
|
||||
if "args" in row and isinstance(row["args"], dict):
|
||||
for k, v in row["args"].items():
|
||||
row[f"args_{k}"] = v
|
||||
del row["args"]
|
||||
name = row["name"]
|
||||
for suf in suffixes:
|
||||
if name.endswith(suf):
|
||||
changed = name[: -len(suf)]
|
||||
row["op_name"] = changed
|
||||
break
|
||||
rows.append(row)
|
||||
exe_providers.append((row.get("args_provider", None), row.get("args_op_name", None)))
|
||||
short_list = [(a, b) for a, b in exe_providers if a is not None and b is not None]
|
||||
self.assertEqual(short_list, [("CUDAExecutionProvider", o) for o in expected_names])
|
||||
|
||||
@unittest.skipIf(not has_cuda(), reason="cuda not available")
|
||||
@ignore_warnings(DeprecationWarning)
|
||||
def test_scatterels_cuda(self):
|
||||
default_value = [
|
||||
"Cast",
|
||||
"ScatterElements",
|
||||
"Sub",
|
||||
]
|
||||
expected = {
|
||||
(np.float32, "none"): default_value,
|
||||
(np.float16, "none"): default_value,
|
||||
(np.float32, "add"): default_value,
|
||||
(np.float16, "add"): default_value,
|
||||
(np.float32, "mul"): default_value,
|
||||
(np.float16, "mul"): default_value,
|
||||
(np.float32, "min"): default_value,
|
||||
(np.float16, "min"): default_value,
|
||||
(np.float32, "max"): default_value,
|
||||
(np.float16, "max"): default_value,
|
||||
}
|
||||
for opset, dtype, reduction in itertools.product(
|
||||
[16, 18], [np.float32, np.float16], ["none", "add", "mul", "min", "max"]
|
||||
):
|
||||
with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
|
||||
self.common_scatter(
|
||||
opset,
|
||||
["CUDAExecutionProvider"],
|
||||
dtype,
|
||||
reduction,
|
||||
expected[dtype, reduction],
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_cuda(), reason="cuda not available")
|
||||
@ignore_warnings(DeprecationWarning)
|
||||
def test_scatternd_cuda(self):
|
||||
default_value = [
|
||||
"Cast",
|
||||
"ScatterND",
|
||||
"Sub",
|
||||
]
|
||||
expected = {
|
||||
(np.float32, "none"): default_value,
|
||||
(np.float16, "none"): default_value,
|
||||
(np.float32, "add"): default_value,
|
||||
(np.float16, "add"): default_value,
|
||||
(np.float32, "mul"): default_value,
|
||||
(np.float16, "mul"): default_value,
|
||||
(np.float32, "min"): default_value,
|
||||
(np.float16, "min"): default_value,
|
||||
(np.float32, "max"): default_value,
|
||||
(np.float16, "max"): default_value,
|
||||
}
|
||||
for opset, dtype, reduction in itertools.product(
|
||||
[16, 18], [np.float32, np.float16], ["none", "add", "mul", "min", "max"]
|
||||
):
|
||||
with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
|
||||
self.common_scatter(
|
||||
opset,
|
||||
["CUDAExecutionProvider"],
|
||||
dtype,
|
||||
reduction,
|
||||
expected[dtype, reduction],
|
||||
)
|
||||
|
||||
@ignore_warnings(DeprecationWarning)
|
||||
def test_scatterels_cpu(self):
|
||||
default_value = [
|
||||
"Cast",
|
||||
"ScatterElements",
|
||||
"Sub",
|
||||
]
|
||||
expected = {
|
||||
(np.float32, "none"): default_value,
|
||||
(np.float16, "none"): default_value,
|
||||
(np.float32, "add"): default_value,
|
||||
(np.float16, "add"): default_value,
|
||||
(np.float32, "mul"): default_value,
|
||||
(np.float16, "mul"): default_value,
|
||||
(np.float32, "min"): default_value,
|
||||
(np.float16, "min"): default_value,
|
||||
(np.float32, "max"): default_value,
|
||||
(np.float16, "max"): default_value,
|
||||
}
|
||||
for opset, dtype, reduction in itertools.product([16, 18], [np.float32], ["none", "add", "mul", "min", "max"]):
|
||||
with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
|
||||
self.common_scatter(
|
||||
opset,
|
||||
["CPUExecutionProvider"],
|
||||
dtype,
|
||||
reduction,
|
||||
expected[dtype, reduction],
|
||||
)
|
||||
|
||||
@ignore_warnings(DeprecationWarning)
|
||||
def test_scatternd_cpu(self):
|
||||
default_value = [
|
||||
"Cast",
|
||||
"ScatterND",
|
||||
"Sub",
|
||||
]
|
||||
expected = {
|
||||
(np.float32, "none"): default_value,
|
||||
(np.float16, "none"): default_value,
|
||||
(np.float32, "add"): default_value,
|
||||
(np.float16, "add"): default_value,
|
||||
(np.float32, "mul"): default_value,
|
||||
(np.float16, "mul"): default_value,
|
||||
(np.float32, "min"): default_value,
|
||||
(np.float16, "min"): default_value,
|
||||
(np.float32, "max"): default_value,
|
||||
(np.float16, "max"): default_value,
|
||||
}
|
||||
for opset, dtype, reduction in itertools.product([16, 18], [np.float32], ["none", "add", "mul", "min", "max"]):
|
||||
with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
|
||||
self.common_scatter(
|
||||
opset,
|
||||
["CPUExecutionProvider"],
|
||||
dtype,
|
||||
reduction,
|
||||
expected[dtype, reduction],
|
||||
)
|
||||
|
||||
def _scatternd_standalone_cuda(self, reduction, line):
|
||||
model = oh.make_model(
|
||||
oh.make_graph(
|
||||
[
|
||||
oh.make_node(
|
||||
"ScatterND",
|
||||
inputs=["data", "indices", "updates"],
|
||||
outputs=["y"],
|
||||
reduction=reduction,
|
||||
)
|
||||
],
|
||||
"nd",
|
||||
[
|
||||
oh.make_tensor_value_info("data", TensorProto.FLOAT, [None, None, None]),
|
||||
oh.make_tensor_value_info("indices", TensorProto.INT64, [None, None]),
|
||||
oh.make_tensor_value_info("updates", TensorProto.FLOAT, [None, None, None]),
|
||||
],
|
||||
[oh.make_tensor_value_info("y", TensorProto.FLOAT, [None, None, None])],
|
||||
),
|
||||
opset_imports=[oh.make_opsetid("", 18)],
|
||||
ir_version=9,
|
||||
)
|
||||
|
||||
data = np.full((2, 2, 3), 0.1, dtype=np.float32)
|
||||
indices = np.array([[line], [1 - line], [line]], dtype=np.int64)
|
||||
updates = (2 ** (np.arange(18) + 1).astype(np.float32).reshape((3, 2, 3))).astype(np.float32)
|
||||
|
||||
feeds = dict(data=data, indices=indices, updates=updates)
|
||||
ref = ReferenceEvaluator(model)
|
||||
expected = ref.run(None, feeds)[0]
|
||||
|
||||
providers = (
|
||||
[
|
||||
["CUDAExecutionProvider"],
|
||||
["CPUExecutionProvider"],
|
||||
]
|
||||
if has_cuda()
|
||||
else [["CPUExecutionProvider"]]
|
||||
)
|
||||
for provider in providers:
|
||||
sess = onnxruntime.InferenceSession(model.SerializeToString(), providers=provider)
|
||||
got = sess.run(None, feeds)[0]
|
||||
self.assertEqual(expected.tolist(), got.tolist())
|
||||
|
||||
def test_scatternd_standalone_cuda(self):
|
||||
self._scatternd_standalone_cuda("add", 0)
|
||||
self._scatternd_standalone_cuda("add", 1)
|
||||
self._scatternd_standalone_cuda("mul", 0)
|
||||
self._scatternd_standalone_cuda("mul", 1)
|
||||
self._scatternd_standalone_cuda("min", 0)
|
||||
self._scatternd_standalone_cuda("min", 1)
|
||||
self._scatternd_standalone_cuda("max", 0)
|
||||
self._scatternd_standalone_cuda("max", 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
@ -19,9 +19,13 @@ class TestDynamicQuantizationSubgraph(unittest.TestCase):
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx")
|
||||
quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx")
|
||||
urllib.request.urlretrieve(
|
||||
"https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path
|
||||
)
|
||||
url = "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx"
|
||||
try:
|
||||
urllib.request.urlretrieve(url, onnx_path)
|
||||
except urllib.request.HTTPError as e:
|
||||
# The unit test should not fail for this kind of issue.
|
||||
# TODO: use another way to retrieve the model.
|
||||
raise unittest.SkipTest(f"Unable to fetch {url!r} due to {e}") # noqa: B904
|
||||
|
||||
quantize_dynamic(
|
||||
model_input=onnx_path,
|
||||
|
|
@ -62,3 +66,7 @@ class TestDynamicQuantizationSubgraph(unittest.TestCase):
|
|||
if attr.type == onnx.AttributeProto.GRAPH:
|
||||
for initializer in attr.g.initializer:
|
||||
self.assertTrue("shared.weight" not in initializer.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
|||
Loading…
Reference in a new issue