Update ScatterElements to Support Opset 13, 15, 18 (#19198)

`ScatterElements` in opset 18 has been around for a while. However, the
highest opset supporting `ScatterElements` in ORT is 13. This PR
implements this op in the CUDA EP by replacing `assignment` in the current
CUDA kernel with `atomic reduction` (e.g., atomic add, atomic max). A
series of fundamental atomic functions (e.g., atomic max for int8_t and
half) are implemented in `common.cuh`; the implementation is general
enough to cover old CUDA and new CUDA versions.

- The core changes are in `cuda/atomic/common.cuh` with very detailed
documentation including `bit-wise operation's visualization`. They are
also copied to `rocm/atomic/common.cuh` to support AMD GPU.
- `/cuda/tensor/gather_elements_impl.cu` contains small changes to call
the new atomic functions to support new `reduction` behavior in new
`ScatterElements`.
- New `ScatterElements` are defined in `rocm_execution_provider.cc` and
`cuda_execution_provider.cc`.
This commit is contained in:
Wei-Sheng Chin 2024-01-30 09:18:50 -08:00 committed by GitHub
parent 3e17ca3dab
commit ffc3431a66
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 858 additions and 25 deletions

View file

@ -744,7 +744,9 @@ Do not modify directly.*
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8|**I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

View file

@ -198,13 +198,6 @@ struct Func_Min<std::string> {
}
};
template <>
struct Func_Min<MLFloat16> {
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
}
};
template <>
struct Func_Min<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {
@ -233,13 +226,6 @@ struct Func_Max<std::string> {
}
};
template <>
struct Func_Max<MLFloat16> {
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
}
};
template <>
struct Func_Max<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {

View file

@ -122,5 +122,316 @@ __device__ __forceinline__ void AtomicAdd<half>(half* start_addr, size_t index,
#endif
}
// Disable default template instantiation.
// For every type T, we need to define a specialization
// to select the right type for calling atomicCAS.
//
// For each specialization:
// - `type` is the integer type of the same byte width passed to atomicCAS
//   when emulating an atomic read-modify-write of T.
// - `mask` selects the bits of a loaded 32-bit word that belong to T in
//   the byte-wise CAS fallback (atomic_byte_func_with_unit32_cas).
template <typename T>
class AtomicCasType;
template<>
class AtomicCasType<int8_t> {
public:
// NOTE(review): `type` is 2 bytes while int8_t is 1 byte; the int8_t path
// only uses `mask` (via the byte-wise CAS helper below), so `type` appears
// unused for this specialization — confirm before relying on it.
using type = unsigned short int;
static const unsigned int mask = 0xffu;
};
template<>
class AtomicCasType<half> {
public:
using type = unsigned short int;
static const unsigned int mask = 0xffffu;
};
template<>
class AtomicCasType<float> {
public:
using type = unsigned int;
static const unsigned int mask = 0xffffffffu;
};
template<>
class AtomicCasType<double> {
public:
using type = unsigned long long int;
// NOTE(review): a 32-bit mask cannot cover an 8-byte value; `mask` is only
// used by the <=4-byte byte-wise CAS path, which 8-byte types never take
// (see the static_assert there), so this appears to be dead for double.
static const unsigned int mask = 0xffffffffu;
};
template<>
class AtomicCasType<int> {
public:
using type = int;
static const unsigned int mask = 0xffffffffu;
};
template<>
class AtomicCasType<int64_t> {
public:
using type = unsigned long long int;
// NOTE(review): same caveat as double — `mask` is not used for 8-byte types.
static const unsigned int mask = 0xffffffffu;
};
// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
//
// This function computes an 8/16/32-bit atomic binary operation using a
// 32-bit atomicCAS. It accumulates `val` into `address` using `func`.
// The accumulation is atomic (i.e., thread-safe).
//
// E.g., assume ValueType is
//   int8_t
// and BinaryFunc is
//   struct AddFunc {
//     __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
//       return a + b;
//     }
//   };
// Then this function becomes atomic_add for int8_t.
//
// NOTE(review): "unit32" in the name is presumably a typo for "uint32"; the
// name is kept unchanged because existing call sites use it.
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
  // Assert to ensure the following bit-wise manipulation is correct.
  // Conditions are combined with logical ||; the original bit-wise | happened
  // to work (each == yields 0 or 1) but obscured the intent.
  static_assert(sizeof(ValueType) == 1 || sizeof(ValueType) == 2 || sizeof(ValueType) == 4,
                "ValueType must be 1-byte, 2-byte or 4-byte large.");
  // Number of bytes to the lower 4-byte aligned address.
  // If the current address is b1010"10", then offset = b10 = 2,
  // which means the current address is 2 bytes away from
  // the lower 4-byte aligned address b1010"00".
  size_t offset = (size_t)address & 3;
  // Find a new 4-byte aligned address `address_as_ui` lower than or equal to
  // `address`, so that the targeted bytes live inside the 4-byte word we load.
  //
  // This address has the following properties:
  //  1. It is 4-byte aligned.
  //  2. It is lower than or equal to `address`.
  //  3. De-referencing it yields a uint32_t word that contains the
  //     ValueType value indicated by `address`.
  //
  // E.g.,
  //  address = b101010
  //  offset  = b101010 & b000011 = b10 = 2
  //  (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
  //  which is 32-bit aligned.
  uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
  uint32_t old = *address_as_ui;
  // Layout of the loaded word when offset = 2:
  //
  // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
  //                    ^                                         ^
  //                    |                                         |
  //                address          <--- offset * 8 (bit)--->  address_as_ui
  //
  // i.e., *address_as_ui contains the ValueType value at `address`,
  // `shift` bits above the low end of the word.
  uint32_t shift = offset * 8;
  uint32_t old_byte;
  uint32_t newval;
  uint32_t assumed;
  do {
    assumed = old;
    // Select the bytes of interest: shift them down to the low end of the
    // word and mask off everything else, so the cast to ValueType below
    // cannot pick up neighbouring bytes.
    //
    //   old >> shift:  00000000 | 00000000 | .. byte 3 .. | .. byte 2 ..
    //   & mask:        00000000 | 00000000 | 00000000     | .. byte 2 ..
    old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
    // Combine the incoming value with the currently stored one.
    auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
    // Keep only the bits occupied by ValueType; bits above it in the
    // reinterpreted word are indeterminate and must be cleared.
    //
    //   reinterpret_cast<uint32_t&>(newrawvalue) & mask:
    //   00000000 | 00000000 | 00000000 | .. new byte 2 ..
    newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
    // Splice the new bytes back into their original position inside the
    // 32-bit word, leaving the surrounding bytes untouched:
    //   old & ~(mask << shift)  clears the target bytes,
    //   newval << shift         moves the new bytes into place.
    //
    //   result: .. byte 3 .. | .. new byte 2 .. | .. byte 1 .. | .. byte 0 ..
    newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
    // Publish atomically. atomicCAS returns the word that was in memory:
    // if another thread changed it since we read `assumed`, retry with the
    // freshly observed word.
    old = atomicCAS(address_as_ui, assumed, newval);
  } while (assumed != old);
}
// It accumulates `val` into the `address` using the `func`.
// This function is thread-safe (i.e., atomic) and uses the standard
// atomicCAS retry loop:
//   1. read the current value,
//   2. compute the combined value,
//   3. publish it with atomicCAS; retry if another thread wrote in between.
//
// The retry comparison is done on the raw CasType bit patterns rather than
// on ValueType itself. Comparing ValueType values directly can fail to
// terminate for floating-point types: if the stored value is NaN, the
// ValueType comparison `observed != assumed` is always true even when the
// CAS succeeded, whereas the bit patterns compare equal.
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
  using CasType = typename AtomicCasType<ValueType>::type;
  static_assert(sizeof(ValueType) == sizeof(CasType),
                "ValueType and CasType must have the same size for calling atomicCAS.");
  auto address_as_cas_type = reinterpret_cast<CasType*>(address);
  // Latest value observed in memory; seeds the first CAS attempt.
  ValueType observed = *address;
  ValueType new_value;
  CasType assumed_as_cas_type, observed_as_cas_type;
  do {
    // Record the bit pattern the new value is computed from.
    assumed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
    // Compute the expected new value from the observed one.
    new_value = func(observed, val);
    auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
    // atomicCAS returns the word that was in memory:
    // 1. If it equals `assumed_as_cas_type`, the swap succeeded and
    //    `address` now holds `new_value`; the loop terminates.
    // 2. Otherwise another thread wrote in between; `observed` is refreshed
    //    from the returned word and the next iteration recomputes
    //    `new_value` from it.
    observed_as_cas_type = atomicCAS(address_as_cas_type, assumed_as_cas_type, new_value_as_cas_type);
    observed = *reinterpret_cast<ValueType*>(&observed_as_cas_type);
  } while (assumed_as_cas_type != observed_as_cas_type);
}
// Binary combiners plugged into the atomic helpers above. Each functor takes
// the incoming value and the value currently stored in memory and returns
// their combination.
struct AddFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs + rhs;
  }
};
struct MulFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs * rhs;
  }
};
// The comparisons below are spelled so that the selected operand is
// identical to the original `b > a ? b : a` / `b < a ? b : a` forms for
// every input, including floating-point NaN.
struct MaxFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs < rhs ? rhs : lhs;
  }
};
struct MinFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs > rhs ? rhs : lhs;
  }
};
// Type-specific atomic entry points used by the scatter/gather kernels.
//
// int8_t: no native atomic support, so every operation goes through the
// 32-bit-CAS byte helper.
__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, AddFunc());
}
__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}
// half: on __CUDA_ARCH__ >= 700 use the CAS loop on the natively-sized
// 16-bit word (AtomicCasType<half>::type is unsigned short int); older
// architectures fall back to the 32-bit-CAS byte helper.
// NOTE(review): 700 presumably marks where 16-bit atomicCAS becomes
// available — confirm against the CUDA toolkit documentation.
__device__ __forceinline__ void atomic_mul(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MulFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
#endif
}
__device__ __forceinline__ void atomic_max(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MaxFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
#endif
}
__device__ __forceinline__ void atomic_min(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MinFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
#endif
}
// float and double: CAS loop on the natively-sized 32-/64-bit word.
// (atomic_add for these types is defined elsewhere; only mul/max/min are
// provided here.)
__device__ __forceinline__ void atomic_mul(float* address, float value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(float* address, float value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(float* address, float value) {
atomic_binary_func(address, value, MinFunc());
}
__device__ __forceinline__ void atomic_mul(double* address, double value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(double* address, double value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(double* address, double value) {
atomic_binary_func(address, value, MinFunc());
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -1046,7 +1046,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax);
@ -1254,6 +1254,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@ -1269,6 +1270,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
@ -1937,7 +1939,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose)>,
// Register only the versioned 13-15 ScatterElements kernel; the stale
// unversioned opset-13 entry would register the same op twice.
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
@ -2138,6 +2140,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
// Opset 17
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
@ -2159,6 +2162,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,

View file

@ -95,7 +95,37 @@ struct OffsetCalculatorFor2D {
// Plain assignment: the update value overwrites whatever is stored at
// start_addr[index]. Used when the ScatterElements `reduction` attribute is
// "none". The struct must define this overload exactly once; the duplicated
// definition (a redefinition error) is removed.
template <class T>
struct FuncAssignment {
  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
    start_addr[index] = value;
  }
};
// Reduction functors for ScatterElements: each combines the update value
// with the element already stored at start_addr[index] through an atomic
// operation, so concurrent updates targeting the same output element are
// combined safely instead of racing.
template <class T>
struct FuncAdd {
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
atomic_add(start_addr + index, value);
}
};
template <class T>
struct FuncMul {
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
atomic_mul(start_addr + index, value);
}
};
template <class T>
struct FuncMax {
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
atomic_max(start_addr + index, value);
}
};
template <class T>
struct FuncMin {
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
atomic_min(start_addr + index, value);
}
};
template <typename T, typename TIndex, bool IsGather, typename OffsetCalcT, typename TFunc>
@ -238,8 +268,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con
// Scatters `updates_data` into `output_data` at the positions described by
// `indices_data`, combining colliding writes according to `args.operation`
// (assignment, or atomic add/mul/max/min).
// Returns a failure Status for an unrecognized operation.
template <typename T, typename TIndex>
Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data,
                           T* output_data, const GatherScatterElementsArgs& args) {
  // Dispatch once on the reduction mode so the kernel is instantiated with
  // the matching combining functor. (A stale unconditional FuncAssignment
  // return preceding this dispatch made the reduction branches unreachable;
  // it is removed.)
  switch (args.operation) {
    case GatherScatterElementsArgs::Operation::NONE:
      return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
                                         FuncAssignment<T>());
    case GatherScatterElementsArgs::Operation::ADD:
      return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
                                         FuncAdd<T>());
    case GatherScatterElementsArgs::Operation::MUL:
      return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
                                         FuncMul<T>());
    case GatherScatterElementsArgs::Operation::MAX:
      return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
                                         FuncMax<T>());
    case GatherScatterElementsArgs::Operation::MIN:
      return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
                                         FuncMin<T>());
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator.");
  }
}
#define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \

View file

@ -10,6 +10,14 @@ namespace onnxruntime {
namespace cuda {
struct GatherScatterElementsArgs {
enum class Operation {
NONE,
ADD,
MUL,
MAX,
MIN
};
int64_t rank;
int64_t axis;
int64_t input_size;
@ -19,6 +27,9 @@ struct GatherScatterElementsArgs {
TArray<fast_divmod> indices_fdms;
TArray<int64_t> indices_strides;
int64_t indices_size;
// operation used to combine values associated the same
// memory location in the output tensor.
Operation operation;
};
template <typename T, typename TIndex>

View file

@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe
DataTypeImpl::GetTensorType<int64_t>()}),
ScatterElements);
ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider,
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind",
std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
ScatterElements);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind",
std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
ScatterElements);
ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector();
CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args);
if (reduction_ == "none") {
args.operation = GatherScatterElementsArgs::Operation::NONE;
} else if (reduction_ == "add") {
args.operation = GatherScatterElementsArgs::Operation::ADD;
} else if (reduction_ == "mul") {
args.operation = GatherScatterElementsArgs::Operation::MUL;
} else if (reduction_ == "min") {
args.operation = GatherScatterElementsArgs::Operation::MIN;
} else if (reduction_ == "max") {
args.operation = GatherScatterElementsArgs::Operation::MAX;
} else {
ORT_THROW("Unsupported reduction type");
}
// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
int dtype = GetElementType(input_tensor->DataType()->Size());
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {

View file

@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel {
// Reads and validates the kernel attributes at construction time.
// - "axis" (required): dimension along which elements are scattered.
// - "reduction" (optional, default "none"): how updates that target the
//   same output element are combined; must be one of
//   none/add/mul/max/min.
ScatterElements(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK(),
"Missing/Invalid 'axis' attribute value");
reduction_ = info.GetAttrOrDefault<std::string>("reduction", "none");
// Reject unknown reduction modes here rather than at compute time.
ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" ||
reduction_ == "mul" || reduction_ == "max" ||
reduction_ == "min",
"Invalid reduction attribute value of ", reduction_);
}
~ScatterElements() = default;
Status ComputeInternal(OpKernelContext* context) const override;
@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel {
struct ComputeImpl;
int64_t axis_;
// "reduction" attribute has been defined since opset 13 but
// we never implemented it. Let's try to support them starting
// with opset 18.
std::string reduction_;
};
} // namespace cuda

View file

@ -59,5 +59,304 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz
atomic_add(start_addr + index, value);
}
// Disable default template instantiation.
// For every type T, we need to define a specialization
// to select the right type for calling atomicCAS.
//
// For each specialization:
// - `type` is the integer type of the same byte width passed to atomicCAS
//   when emulating an atomic read-modify-write of T.
// - `mask` selects the bits of a loaded 32-bit word that belong to T in
//   the byte-wise CAS fallback (atomic_byte_func_with_unit32_cas).
template <typename T>
class AtomicCasType;
template<>
class AtomicCasType<int8_t> {
public:
// NOTE(review): `type` is 2 bytes while int8_t is 1 byte; the int8_t path
// only uses `mask` (via the byte-wise CAS helper below), so `type` appears
// unused for this specialization — confirm before relying on it.
using type = unsigned short int;
static const unsigned int mask = 0xffu;
};
template<>
class AtomicCasType<half> {
public:
using type = unsigned short int;
static const unsigned int mask = 0xffffu;
};
template<>
class AtomicCasType<float> {
public:
using type = unsigned int;
static const unsigned int mask = 0xffffffffu;
};
template<>
class AtomicCasType<double> {
public:
using type = unsigned long long int;
// NOTE(review): a 32-bit mask cannot cover an 8-byte value; `mask` is only
// used by the <=4-byte byte-wise CAS path, which 8-byte types never take
// (see the static_assert there), so this appears to be dead for double.
static const unsigned int mask = 0xffffffffu;
};
template<>
class AtomicCasType<int> {
public:
using type = int;
static const unsigned int mask = 0xffffffffu;
};
template<>
class AtomicCasType<int64_t> {
public:
using type = unsigned long long int;
// NOTE(review): same caveat as double — `mask` is not used for 8-byte types.
static const unsigned int mask = 0xffffffffu;
};
// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
//
// This function computes an 8/16/32-bit atomic binary operation using a
// 32-bit atomicCAS. It accumulates `val` into `address` using `func`.
// The accumulation is atomic (i.e., thread-safe).
//
// E.g., assume ValueType is
//   int8_t
// and BinaryFunc is
//   struct AddFunc {
//     __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
//       return a + b;
//     }
//   };
// Then this function becomes atomic_add for int8_t.
//
// NOTE(review): "unit32" in the name is presumably a typo for "uint32"; the
// name is kept unchanged because existing call sites use it.
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
  // Assert to ensure the following bit-wise manipulation is correct.
  // Conditions are combined with logical ||; the original bit-wise | happened
  // to work (each == yields 0 or 1) but obscured the intent.
  static_assert(sizeof(ValueType) == 1 || sizeof(ValueType) == 2 || sizeof(ValueType) == 4,
                "ValueType must be 1-byte, 2-byte or 4-byte large.");
  // Number of bytes to the lower 4-byte aligned address.
  // If the current address is b1010"10", then offset = b10 = 2,
  // which means the current address is 2 bytes away from
  // the lower 4-byte aligned address b1010"00".
  size_t offset = (size_t)address & 3;
  // Find a new 4-byte aligned address `address_as_ui` lower than or equal to
  // `address`, so that the targeted bytes live inside the 4-byte word we load.
  //
  // This address has the following properties:
  //  1. It is 4-byte aligned.
  //  2. It is lower than or equal to `address`.
  //  3. De-referencing it yields a uint32_t word that contains the
  //     ValueType value indicated by `address`.
  //
  // E.g.,
  //  address = b101010
  //  offset  = b101010 & b000011 = b10 = 2
  //  (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
  //  which is 32-bit aligned.
  uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
  uint32_t old = *address_as_ui;
  // Layout of the loaded word when offset = 2:
  //
  // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
  //                    ^                                         ^
  //                    |                                         |
  //                address          <--- offset * 8 (bit)--->  address_as_ui
  //
  // i.e., *address_as_ui contains the ValueType value at `address`,
  // `shift` bits above the low end of the word.
  uint32_t shift = offset * 8;
  uint32_t old_byte;
  uint32_t newval;
  uint32_t assumed;
  do {
    assumed = old;
    // Select the bytes of interest: shift them down to the low end of the
    // word and mask off everything else, so the cast to ValueType below
    // cannot pick up neighbouring bytes.
    //
    //   old >> shift:  00000000 | 00000000 | .. byte 3 .. | .. byte 2 ..
    //   & mask:        00000000 | 00000000 | 00000000     | .. byte 2 ..
    old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
    // Combine the incoming value with the currently stored one.
    auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
    // Keep only the bits occupied by ValueType; bits above it in the
    // reinterpreted word are indeterminate and must be cleared.
    //
    //   reinterpret_cast<uint32_t&>(newrawvalue) & mask:
    //   00000000 | 00000000 | 00000000 | .. new byte 2 ..
    newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
    // Splice the new bytes back into their original position inside the
    // 32-bit word, leaving the surrounding bytes untouched:
    //   old & ~(mask << shift)  clears the target bytes,
    //   newval << shift         moves the new bytes into place.
    //
    //   result: .. byte 3 .. | .. new byte 2 .. | .. byte 1 .. | .. byte 0 ..
    newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
    // Publish atomically. atomicCAS returns the word that was in memory:
    // if another thread changed it since we read `assumed`, retry with the
    // freshly observed word.
    old = atomicCAS(address_as_ui, assumed, newval);
  } while (assumed != old);
}
// It accumulates `val` into the `address` using the `func`.
// This function is thread-safe (i.e., atomic).
// `ValueType` must be bit-castable to an integer type accepted by atomicCAS
// of the same width (enforced by the static_assert below); e.g.,
// float <-> unsigned int, double <-> unsigned long long int.
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
ValueType observed = *address, assumed, new_value;
using CasType = typename AtomicCasType<ValueType>::type;
static_assert(sizeof(ValueType) == sizeof(CasType),
"ValueType and CasType must have the same size for calling atomicCAS.");
auto address_as_cas_type = reinterpret_cast<CasType*>(address);
do {
// Record the value used to compute new value.
assumed = observed;
// Compute expected new value.
new_value = func(observed, val);
// Bit-cast ValueType to the same-width integer type supported by atomicCAS:
//   sizeof(ValueType) == 4  ->  unsigned int (or int)
//   sizeof(ValueType) == 8  ->  unsigned long long int
auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
// Call atomicCAS as if the variables were of CasType.
// It returns the value found at `address` at the time of the compare.
auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);
// Cast the freshly observed value in memory back to ValueType.
observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
// Two cases:
// 1. compare-and-swap succeeds:
//    a. `address` now holds `new_value`.
//    b. `observed` equals `assumed` (the value atomicCAS compared against).
//       Thus, the following `observed != assumed` is false,
//       and the loop terminates.
// 2. compare-and-swap fails:
//    a. `address` held a value different from `assumed`, thus,
//       the `new_value` is stale.
//    b. `observed` becomes the fresh value found at `address`.
//       Thus, the following `observed != assumed` is true,
//       and the loop continues. In the next iteration, the
//       `new_value` is computed again using the fresh `observed`.
} while (observed != assumed);
}
// Binary functor returning the sum of its two operands.
struct AddFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs + rhs;
  }
};
// Binary functor returning the product of its two operands.
struct MulFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs * rhs;
  }
};
// Binary functor returning the larger of its two operands
// (the first operand when the comparison is false, e.g., on ties).
struct MaxFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs < rhs ? rhs : lhs;
  }
};
// Binary functor returning the smaller of its two operands
// (the first operand when the comparison is false, e.g., on ties).
struct MinFunc {
  template <typename T>
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs > rhs ? rhs : lhs;
  }
};
// Atomic add/mul/max/min for int8_t. There is no byte-wide atomicCAS, so these
// forward to atomic_byte_func_with_unit32_cas, which emulates the 8-bit atomic
// with a 32-bit compare-and-swap on the enclosing aligned word (see the
// bit-manipulation walkthrough in that function above).
__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, AddFunc());
}
__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}
// Atomic mul/max/min for half, emulated via a 32-bit compare-and-swap on the
// enclosing aligned word (atomic_byte_func_with_unit32_cas handles sub-word
// types).
// NOTE(review): atomic_add(half*, half) is not defined in this group —
// presumably a native/hardware atomicAdd overload is used elsewhere; verify
// against callers.
__device__ __forceinline__ void atomic_mul(half* address, half value) {
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(half* address, half value) {
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(half* address, half value) {
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}
// Atomic mul/max/min for float (4 bytes) and double (8 bytes). These widths
// match integer types supported by atomicCAS directly, so atomic_binary_func
// performs the read-modify-write with a same-width compare-and-swap loop.
__device__ __forceinline__ void atomic_mul(float* address, float value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(float* address, float value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(float* address, float value) {
atomic_binary_func(address, value, MinFunc());
}
__device__ __forceinline__ void atomic_mul(double* address, double value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(double* address, double value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(double* address, double value) {
atomic_binary_func(address, value, MinFunc());
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -1069,7 +1069,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax);
@ -1290,6 +1290,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@ -1302,7 +1303,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);
// Opset 19
@ -2004,7 +2005,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
@ -2225,6 +2226,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
// Opset 17
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
@ -2237,7 +2239,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split)>,
// Opset 19

View file

@ -302,5 +302,137 @@ TEST(Scatter, BoolInputWithAxis) {
scatter_bool_with_axis_tests("ScatterElements", 11);
}
// Opset-18 ScatterElements with reduction="add": every update targeting the
// same output element is accumulated into it. All indices point at row 1, so
// row 0 is untouched and each element of row 1 gains 1+2+3+4.
TEST(ScatterElements, AddReduction) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "add");
  tester.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  tester.AddInput<int64_t>("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  tester.AddInput<float>("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f});
  tester.AddOutput<float>("y", {2, 3},
                          {-9.f, -4.f, -1.f,
                           -7.f + (1.f + 2.f + 3.f + 4.f),
                           -3.f + (1.f + 2.f + 3.f + 4.f),
                           -6.f + (1.f + 2.f + 3.f + 4.f)});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="add" along axis 1: every index selects column 1 of its own row,
// so each row's four updates are summed into that row's middle element.
TEST(ScatterElements, AddReductionAxis1) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<std::string>("reduction", "add");
  tester.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
  tester.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
  tester.AddInput<float>("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f});
  tester.AddOutput<float>("y", {2, 3},
                          {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f,
                           7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="mul" along axis 0: both update rows hit row 1, so each element of
// row 1 is multiplied by its two (equal) updates; row 0 is untouched.
TEST(ScatterElements, MulReduction) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "mul");
  tester.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  tester.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  tester.AddInput<float>("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f});
  tester.AddOutput<float>("y", {2, 3},
                          {-9.f, -4.f, -1.f,
                           -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="mul" along axis 1: every index selects column 1 of its own row,
// so each row's four updates are multiplied into that row's middle element.
TEST(ScatterElements, MulReductionAxis1) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<std::string>("reduction", "mul");
  tester.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
  tester.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
  tester.AddInput<float>("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
  tester.AddOutput<float>("y", {2, 3},
                          {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f,
                           7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="max" with half-precision data: row 1 becomes, elementwise, the
// maximum of its original value and the two updates scattered onto it.
TEST(ScatterElements, MaxReduction_MLFloat16) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "max");
  tester.AddInput<MLFloat16>("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}));
  tester.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  tester.AddInput<MLFloat16>("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f}));
  tester.AddOutput<MLFloat16>("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}));
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="max" with float data: row 1 becomes, elementwise, the maximum of
// its original value and the two updates scattered onto it.
TEST(ScatterElements, MaxReduction_Float) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "max");
  tester.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  tester.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  tester.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  tester.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="max" with double data: row 1 becomes, elementwise, the maximum of
// its original value and the two updates scattered onto it.
TEST(ScatterElements, MaxReduction_Double) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "max");
  tester.AddInput<double>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  tester.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  tester.AddInput<double>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  tester.AddOutput<double>("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="min" with half-precision data: row 1 becomes, elementwise, the
// minimum of its original value and the two updates scattered onto it.
TEST(ScatterElements, MinReduction_MLFloat16) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "min");
  tester.AddInput<MLFloat16>("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}));
  tester.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  tester.AddInput<MLFloat16>("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f}));
  tester.AddOutput<MLFloat16>("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}));
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="min" with float data: row 1 becomes, elementwise, the minimum of
// its original value and the two updates scattered onto it.
TEST(ScatterElements, MinReduction_Float) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "min");
  tester.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f});
  tester.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  tester.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  tester.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
// reduction="min" with double data: row 1 becomes, elementwise, the minimum of
// its original value and the two updates scattered onto it.
TEST(ScatterElements, MinReduction_Double) {
  OpTester tester("ScatterElements", 18);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<std::string>("reduction", "min");
  tester.AddInput<double>("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f});
  tester.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  tester.AddInput<double>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  tester.AddOutput<double>("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
} // namespace test
} // namespace onnxruntime