mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update ScatterElements to Support Opset 13, 15, 18 (#19198)
`ScatterElements` in opset 18 has been around for a while. However, the highest opset supporting `ScatterElements` in ORT is 13. This PR implements this op in the CUDA EP by replacing `assignment` in the current CUDA kernel with `atomic reduction` (e.g., atomic add, atomic max). A series of fundamental atomic functions (e.g., atomic max for int8_t and half) are implemented in `common.cuh`; the implementation is general enough to cover both old and new CUDA versions. - The core changes are in `cuda/atomic/common.cuh` with very detailed documentation including a `bit-wise operation visualization`. They are also copied to `rocm/atomic/common.cuh` to support AMD GPUs. - `/cuda/tensor/gather_elements_impl.cu` contains small changes to call the new atomic functions to support the new `reduction` behavior in the new `ScatterElements`. - The new `ScatterElements` kernels are registered in `rocm_execution_provider.cc` and `cuda_execution_provider.cc`.
This commit is contained in:
parent
3e17ca3dab
commit
ffc3431a66
11 changed files with 858 additions and 25 deletions
|
|
@ -744,7 +744,9 @@ Do not modify directly.*
|
|||
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||8|**I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -198,13 +198,6 @@ struct Func_Min<std::string> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min<MLFloat16> {
|
||||
void operator()(MLFloat16*, const MLFloat16*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Min<BFloat16> {
|
||||
void operator()(BFloat16*, const BFloat16*) const {
|
||||
|
|
@ -233,13 +226,6 @@ struct Func_Max<std::string> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max<MLFloat16> {
|
||||
void operator()(MLFloat16*, const MLFloat16*) const {
|
||||
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Func_Max<BFloat16> {
|
||||
void operator()(BFloat16*, const BFloat16*) const {
|
||||
|
|
|
|||
|
|
@ -122,5 +122,316 @@ __device__ __forceinline__ void AtomicAdd<half>(half* start_addr, size_t index,
|
|||
#endif
|
||||
}
|
||||
|
||||
// Disable default template instantiation.
|
||||
// For every type T, we need to define a specialization
|
||||
// to select the right type for calling atomicCAS.
|
||||
template <typename T>
|
||||
class AtomicCasType;
|
||||
|
||||
template<>
|
||||
class AtomicCasType<int8_t> {
|
||||
public:
|
||||
using type = unsigned short int;
|
||||
static const unsigned int mask = 0xffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<half> {
|
||||
public:
|
||||
using type = unsigned short int;
|
||||
static const unsigned int mask = 0xffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<float> {
|
||||
public:
|
||||
using type = unsigned int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<double> {
|
||||
public:
|
||||
using type = unsigned long long int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<int> {
|
||||
public:
|
||||
using type = int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<int64_t> {
|
||||
public:
|
||||
using type = unsigned long long int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
|
||||
//
|
||||
// This function compute 8-bit atomic binary operation using 32-bit atomicCAS.
|
||||
// It accumulate `val` into the `address` using the `func`.
|
||||
// The accumulation is atomic (i.e., thread-safe).
|
||||
//
|
||||
// E.g., Assume ValueType is
|
||||
// int8_t
|
||||
// and BinaryFunc is
|
||||
// struct AddFunc {
|
||||
// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
|
||||
// return a + b;
|
||||
// }
|
||||
// This function becomes atomic_add for int8_t.
|
||||
template<typename ValueType, typename BinaryFunc>
|
||||
__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
|
||||
// Assert to ensure the following bit-wise manipulation is correct.
|
||||
static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4,
|
||||
"ValueType must be 1-byte, 2-byte or 4-byte large.");
|
||||
// Number of bytes to the lower 4-byte aligned address.
|
||||
// If the current address is b1010"10", then offset = b10 = 2,
|
||||
// which means the current address is 2 bytes away from
|
||||
// the lower 4-byte aligned address b1010"00".
|
||||
size_t offset = (size_t)address & 3;
|
||||
// Find an new 4-byte aligned address `address_as_ui` lower than
|
||||
// or equal to `address`. Lower than `address` so that the actual
|
||||
// int8_t byte is in the 4-byte word that we load.
|
||||
//
|
||||
// This address has the following properties:
|
||||
// 1. It is 4-byte aligned.
|
||||
// 2. It is lower than or equal to `address`.
|
||||
// 3. De-referencing this address may return
|
||||
// a uint32_t value that contains the same int8_t
|
||||
// value indicated by `address`.
|
||||
//
|
||||
// E.g.,
|
||||
// address = b101010
|
||||
// offset = b101010 & b000011 = b10 = 2
|
||||
// (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
|
||||
// which is (32-bit aligned).
|
||||
uint32_t * address_as_ui = (uint32_t*)((char*)address - offset);
|
||||
uint32_t old = *address_as_ui;
|
||||
// E.g., offset = 2.
|
||||
// address_as_ui is an address 2 bytes lower than `address`.
|
||||
//
|
||||
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
// ^ ^ ^
|
||||
// | | |
|
||||
// | address <--- offset * 8 (bit)-----> address_as_ui
|
||||
// | ^
|
||||
// | |
|
||||
// ------------------------- *address_as_ui -----------------------
|
||||
//
|
||||
// This visualization shows
|
||||
// 1. the 32-bit word at address_as_ui.
|
||||
// 2. the gap between address_as_ui and address.
|
||||
// 3. *address_as_ui contains the int8_t value at `address`.
|
||||
uint32_t shift = offset * 8;
|
||||
uint32_t old_byte;
|
||||
uint32_t newval;
|
||||
uint32_t assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
// Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so
|
||||
// we want to select the 3rd byte (byte 2 below) from the word.
|
||||
//
|
||||
// Journey of a 32-bit value:
|
||||
//
|
||||
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
//
|
||||
// |
|
||||
// | old >> offset * 8, where offset = 2.
|
||||
// | Effectively, push lower two bytes
|
||||
// | out of the word.
|
||||
// V
|
||||
//
|
||||
// 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 .....
|
||||
//
|
||||
// | apply bit-wise AND,
|
||||
// | & 0xff (i.e., & b11111111),
|
||||
// | so that we only keep
|
||||
// | the byte of interest.
|
||||
// | Otherwise, overflow may
|
||||
// | happen when casting this
|
||||
// | 32-bit value to int8_t.
|
||||
// V
|
||||
//
|
||||
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
|
||||
old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
|
||||
// Compute new int8_t value and store it to newrawvalue.
|
||||
// Journey of a 32-bit value (cont'd):
|
||||
//
|
||||
// newrawvalue
|
||||
// ... new byte 2 ...
|
||||
auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
|
||||
// Put the new int8_t value back to 32-bit word.
|
||||
// Also ensure that bits not occupied by the int8_t value are 0s.
|
||||
//
|
||||
// Journey of a 32-bit value (cont'd):
|
||||
//
|
||||
// reinterpret_cast<uint32_t&>(newrawvalue)
|
||||
// random values | random values | random values | ... new byte 2 ...
|
||||
//
|
||||
// reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
|
||||
// 00000000 | 00000000 | 00000000 | ... new byte 2 ...
|
||||
newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
|
||||
// Journey of a 32-bit value (cont'd):
|
||||
//
|
||||
// old
|
||||
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
//
|
||||
// 0x000000ff
|
||||
// 00000000 | 00000000 | 00000000 | 11111111
|
||||
//
|
||||
// 0x000000ff << shift
|
||||
// 00000000 | 11111111 | 00000000 | 00000000
|
||||
//
|
||||
// ~(0x000000ff << shift)
|
||||
// 11111111 | 00000000 | 11111111 | 11111111
|
||||
//
|
||||
// old & ~(0x000000ff << shift)
|
||||
// ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 .....
|
||||
//
|
||||
// newval << shift
|
||||
// 00000000 | ... new byte 2 ... | 00000000 | 00000000
|
||||
//
|
||||
// (old & ~(0x000000ff << shift)) | (newval << shift)
|
||||
// ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
|
||||
old = atomicCAS(address_as_ui, assumed, newval);
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
// It accumulates `val` into the `address` using the `func`.
|
||||
// This function is thread-safe (i.e., atomic).
|
||||
template<typename ValueType, typename BinaryFunc>
|
||||
__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
|
||||
ValueType observed = *address, assumed, new_value;
|
||||
using CasType = typename AtomicCasType<ValueType>::type;
|
||||
static_assert(sizeof(ValueType) == sizeof(CasType),
|
||||
"ValueType and CasType must have the same size for calling atomicCAS.");
|
||||
auto address_as_cas_type = reinterpret_cast<CasType*>(address);
|
||||
do {
|
||||
// Record the value used to compute new value.
|
||||
assumed = observed;
|
||||
|
||||
// Compute expected new value.
|
||||
new_value = func(observed, val);
|
||||
|
||||
// Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS.
|
||||
// 4
|
||||
// 8
|
||||
auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
|
||||
auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
|
||||
|
||||
// Call atomicCAS as if the 2-byte type variables are all unsigned short int.
|
||||
// 4 unsigned int (or int)
|
||||
// 8 unsigned long long int
|
||||
auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);
|
||||
|
||||
// Cast the freshly observed value in memory back to the TwoByteType.
|
||||
observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
|
||||
|
||||
// Two cases:
|
||||
// 1. compare-and-swap success
|
||||
// a. `address` holds `new_value`
|
||||
// b. `observed` becomes the new value after the assignment.
|
||||
// Thus, the following `observed != new_value` is false,
|
||||
// and the loop terminates.
|
||||
// 2. compare-and-swap fails
|
||||
// a. `address` holds a value different from `observed`, thus,
|
||||
// the `new_value` is stale.
|
||||
// b. `observed` becomes the fresh value observed in `address`.
|
||||
// Thus, the following (observed != new_value) is true,
|
||||
// and the loop continues. In the next iteration, the
|
||||
// `new_value` is computed again using the fresh `observed`.
|
||||
} while (observed != assumed);
|
||||
}
|
||||
|
||||
struct AddFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct MulFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return b > a ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
struct MinFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return b < a ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
// int8_t has no hardware atomics; every operation is emulated with a 32-bit
// CAS on the 4-byte word that contains the target byte.
__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, AddFunc());
}

__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}

__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}

__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}
|
||||
|
||||
// half: on sm_70+ a same-width 16-bit atomicCAS is available, so the direct
// CAS loop is used; older architectures fall back to the byte-level 32-bit
// CAS emulation.
__device__ __forceinline__ void atomic_mul(half* address, half value) {
#if __CUDA_ARCH__ >= 700
  atomic_binary_func(address, value, MulFunc());
#else
  atomic_byte_func_with_unit32_cas(address, value, MulFunc());
#endif
}

__device__ __forceinline__ void atomic_max(half* address, half value) {
#if __CUDA_ARCH__ >= 700
  atomic_binary_func(address, value, MaxFunc());
#else
  atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
#endif
}

__device__ __forceinline__ void atomic_min(half* address, half value) {
#if __CUDA_ARCH__ >= 700
  atomic_binary_func(address, value, MinFunc());
#else
  atomic_byte_func_with_unit32_cas(address, value, MinFunc());
#endif
}
|
||||
|
||||
// float and double use a same-width atomicCAS loop (32-bit and 64-bit
// respectively); no byte-level emulation is needed.
__device__ __forceinline__ void atomic_mul(float* address, float value) {
  atomic_binary_func(address, value, MulFunc());
}

__device__ __forceinline__ void atomic_max(float* address, float value) {
  atomic_binary_func(address, value, MaxFunc());
}

__device__ __forceinline__ void atomic_min(float* address, float value) {
  atomic_binary_func(address, value, MinFunc());
}

__device__ __forceinline__ void atomic_mul(double* address, double value) {
  atomic_binary_func(address, value, MulFunc());
}

__device__ __forceinline__ void atomic_max(double* address, double value) {
  atomic_binary_func(address, value, MaxFunc());
}

__device__ __forceinline__ void atomic_min(double* address, double value) {
  atomic_binary_func(address, value, MinFunc());
}
|
||||
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1046,7 +1046,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax);
|
||||
|
|
@ -1254,6 +1254,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
|
||||
|
||||
// Opset 17
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
|
||||
|
|
@ -1269,6 +1270,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
|
||||
|
|
@ -1937,7 +1939,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
|
||||
|
|
@ -2138,6 +2140,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
|
||||
|
||||
// Opset 17
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
|
||||
|
|
@ -2159,6 +2162,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,37 @@ struct OffsetCalculatorFor2D {
|
|||
|
||||
template <class T>
|
||||
struct FuncAssignment {
|
||||
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; }
|
||||
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
|
||||
start_addr[index] = value;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct FuncAdd {
|
||||
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
|
||||
atomic_add(start_addr + index, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct FuncMul {
|
||||
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
|
||||
atomic_mul(start_addr + index, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct FuncMax {
|
||||
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
|
||||
atomic_max(start_addr + index, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct FuncMin {
|
||||
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
|
||||
atomic_min(start_addr + index, value);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename TIndex, bool IsGather, typename OffsetCalcT, typename TFunc>
|
||||
|
|
@ -238,8 +268,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con
|
|||
template <typename T, typename TIndex>
|
||||
Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data,
|
||||
T* output_data, const GatherScatterElementsArgs& args) {
|
||||
return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
|
||||
FuncAssignment<T>());
|
||||
if (args.operation == GatherScatterElementsArgs::Operation::NONE) {
|
||||
return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
|
||||
FuncAssignment<T>());
|
||||
} else if (args.operation == GatherScatterElementsArgs::Operation::ADD) {
|
||||
return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
|
||||
FuncAdd<T>());
|
||||
} else if (args.operation == GatherScatterElementsArgs::Operation::MUL) {
|
||||
return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
|
||||
FuncMul<T>());
|
||||
} else if (args.operation == GatherScatterElementsArgs::Operation::MAX) {
|
||||
return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
|
||||
FuncMax<T>());
|
||||
} else if (args.operation == GatherScatterElementsArgs::Operation::MIN) {
|
||||
return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
|
||||
FuncMin<T>());
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator.");
|
||||
}
|
||||
}
|
||||
|
||||
#define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
|
||||
|
|
|
|||
|
|
@ -10,6 +10,14 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
struct GatherScatterElementsArgs {
|
||||
enum class Operation {
|
||||
NONE,
|
||||
ADD,
|
||||
MUL,
|
||||
MAX,
|
||||
MIN
|
||||
};
|
||||
|
||||
int64_t rank;
|
||||
int64_t axis;
|
||||
int64_t input_size;
|
||||
|
|
@ -19,6 +27,9 @@ struct GatherScatterElementsArgs {
|
|||
TArray<fast_divmod> indices_fdms;
|
||||
TArray<int64_t> indices_strides;
|
||||
int64_t indices_size;
|
||||
// operation used to combine values associated the same
|
||||
// memory location in the output tensor.
|
||||
Operation operation;
|
||||
};
|
||||
|
||||
template <typename T, typename TIndex>
|
||||
|
|
|
|||
|
|
@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe
|
|||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
ScatterElements);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider,
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("Tind",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
ScatterElements);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("Tind",
|
||||
std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
ScatterElements);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
|
||||
|
|
@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
|
|||
TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector();
|
||||
CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args);
|
||||
|
||||
if (reduction_ == "none") {
|
||||
args.operation = GatherScatterElementsArgs::Operation::NONE;
|
||||
} else if (reduction_ == "add") {
|
||||
args.operation = GatherScatterElementsArgs::Operation::ADD;
|
||||
} else if (reduction_ == "mul") {
|
||||
args.operation = GatherScatterElementsArgs::Operation::MUL;
|
||||
} else if (reduction_ == "min") {
|
||||
args.operation = GatherScatterElementsArgs::Operation::MIN;
|
||||
} else if (reduction_ == "max") {
|
||||
args.operation = GatherScatterElementsArgs::Operation::MAX;
|
||||
} else {
|
||||
ORT_THROW("Unsupported reduction type");
|
||||
}
|
||||
|
||||
// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
|
||||
int dtype = GetElementType(input_tensor->DataType()->Size());
|
||||
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel {
|
|||
// Reads and validates the kernel attributes. `axis` is mandatory;
// `reduction` defaults to "none" and must be one of the values the CUDA
// implementation supports.
ScatterElements(const OpKernelInfo& info) : CudaKernel(info) {
  ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK(),
              "Missing/Invalid 'axis' attribute value");
  reduction_ = info.GetAttrOrDefault<std::string>("reduction", "none");

  const bool reduction_is_supported = reduction_ == "none" || reduction_ == "add" ||
                                      reduction_ == "mul" || reduction_ == "max" ||
                                      reduction_ == "min";
  ORT_ENFORCE(reduction_is_supported,
              "Invalid reduction attribute value of ", reduction_);
}
|
||||
~ScatterElements() = default;
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel {
|
|||
struct ComputeImpl;
|
||||
|
||||
int64_t axis_;
|
||||
// "reduction" attribute has been defined since opset 13 but
|
||||
// we never implemented it. Let's try to support them starting
|
||||
// with opset 18.
|
||||
std::string reduction_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -59,5 +59,304 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz
|
|||
atomic_add(start_addr + index, value);
|
||||
}
|
||||
|
||||
// Disable default template instantiation.
|
||||
// For every type T, we need to define a specialization
|
||||
// to select the right type for calling atomicCAS.
|
||||
template <typename T>
|
||||
class AtomicCasType;
|
||||
|
||||
template<>
|
||||
class AtomicCasType<int8_t> {
|
||||
public:
|
||||
using type = unsigned short int;
|
||||
static const unsigned int mask = 0xffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<half> {
|
||||
public:
|
||||
using type = unsigned short int;
|
||||
static const unsigned int mask = 0xffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<float> {
|
||||
public:
|
||||
using type = unsigned int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<double> {
|
||||
public:
|
||||
using type = unsigned long long int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<int> {
|
||||
public:
|
||||
using type = int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
template<>
|
||||
class AtomicCasType<int64_t> {
|
||||
public:
|
||||
using type = unsigned long long int;
|
||||
static const unsigned int mask = 0xffffffffu;
|
||||
};
|
||||
|
||||
// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
|
||||
//
|
||||
// This function compute 8-bit atomic binary operation using 32-bit atomicCAS.
|
||||
// It accumulate `val` into the `address` using the `func`.
|
||||
// The accumulation is atomic (i.e., thread-safe).
|
||||
//
|
||||
// E.g., Assume ValueType is
|
||||
// int8_t
|
||||
// and BinaryFunc is
|
||||
// struct AddFunc {
|
||||
// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
|
||||
// return a + b;
|
||||
// }
|
||||
// This function becomes atomic_add for int8_t.
|
||||
template<typename ValueType, typename BinaryFunc>
|
||||
__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
|
||||
// Assert to ensure the following bit-wise manipulation is correct.
|
||||
static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4,
|
||||
"ValueType must be 1-byte, 2-byte or 4-byte large.");
|
||||
// Number of bytes to the lower 4-byte aligned address.
|
||||
// If the current address is b1010"10", then offset = b10 = 2,
|
||||
// which means the current address is 2 bytes away from
|
||||
// the lower 4-byte aligned address b1010"00".
|
||||
size_t offset = (size_t)address & 3;
|
||||
// Find an new 4-byte aligned address `address_as_ui` lower than
|
||||
// or equal to `address`. Lower than `address` so that the actual
|
||||
// int8_t byte is in the 4-byte word that we load.
|
||||
//
|
||||
// This address has the following properties:
|
||||
// 1. It is 4-byte aligned.
|
||||
// 2. It is lower than or equal to `address`.
|
||||
// 3. De-referencing this address may return
|
||||
// a uint32_t value that contains the same int8_t
|
||||
// value indicated by `address`.
|
||||
//
|
||||
// E.g.,
|
||||
// address = b101010
|
||||
// offset = b101010 & b000011 = b10 = 2
|
||||
// (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
|
||||
// which is (32-bit aligned).
|
||||
uint32_t * address_as_ui = (uint32_t*)((char*)address - offset);
|
||||
uint32_t old = *address_as_ui;
|
||||
// E.g., offset = 2.
|
||||
// address_as_ui is an address 2 bytes lower than `address`.
|
||||
//
|
||||
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
// ^ ^ ^
|
||||
// | | |
|
||||
// | address <--- offset * 8 (bit)-----> address_as_ui
|
||||
// | ^
|
||||
// | |
|
||||
// ------------------------- *address_as_ui -----------------------
|
||||
//
|
||||
// This visualization shows
|
||||
// 1. the 32-bit word at address_as_ui.
|
||||
// 2. the gap between address_as_ui and address.
|
||||
// 3. *address_as_ui contains the int8_t value at `address`.
|
||||
uint32_t shift = offset * 8;
|
||||
uint32_t old_byte;
|
||||
uint32_t newval;
|
||||
uint32_t assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
// Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so
|
||||
// we want to select the 3rd byte (byte 2 below) from the word.
|
||||
//
|
||||
// Journey of a 32-bit value:
|
||||
//
|
||||
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
//
|
||||
// |
|
||||
// | old >> offset * 8, where offset = 2.
|
||||
// | Effectively, push lower two bytes
|
||||
// | out of the word.
|
||||
// V
|
||||
//
|
||||
// 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 .....
|
||||
//
|
||||
// | apply bit-wise AND,
|
||||
// | & 0xff (i.e., & b11111111),
|
||||
// | so that we only keep
|
||||
// | the byte of interest.
|
||||
// | Otherwise, overflow may
|
||||
// | happen when casting this
|
||||
// | 32-bit value to int8_t.
|
||||
// V
|
||||
//
|
||||
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
|
||||
old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
|
||||
// Compute new int8_t value and store it to newrawvalue.
|
||||
// Journey of a 32-bit value (cont'd):
|
||||
//
|
||||
// newrawvalue
|
||||
// ... new byte 2 ...
|
||||
auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
|
||||
// Put the new int8_t value back to 32-bit word.
|
||||
// Also ensure that bits not occupied by the int8_t value are 0s.
|
||||
//
|
||||
// Journey of a 32-bit value (cont'd):
|
||||
//
|
||||
// reinterpret_cast<uint32_t&>(newrawvalue)
|
||||
// random values | random values | random values | ... new byte 2 ...
|
||||
//
|
||||
// reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
|
||||
// 00000000 | 00000000 | 00000000 | ... new byte 2 ...
|
||||
newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
|
||||
// Journey of a 32-bit value (cont'd):
|
||||
//
|
||||
// old
|
||||
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
//
|
||||
// 0x000000ff
|
||||
// 00000000 | 00000000 | 00000000 | 11111111
|
||||
//
|
||||
// 0x000000ff << shift
|
||||
// 00000000 | 11111111 | 00000000 | 00000000
|
||||
//
|
||||
// ~(0x000000ff << shift)
|
||||
// 11111111 | 00000000 | 11111111 | 11111111
|
||||
//
|
||||
// old & ~(0x000000ff << shift)
|
||||
// ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 .....
|
||||
//
|
||||
// newval << shift
|
||||
// 00000000 | ... new byte 2 ... | 00000000 | 00000000
|
||||
//
|
||||
// (old & ~(0x000000ff << shift)) | (newval << shift)
|
||||
// ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 .....
|
||||
newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
|
||||
old = atomicCAS(address_as_ui, assumed, newval);
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
// It accumulates `val` into the `address` using the `func`.
|
||||
// This function is thread-safe (i.e., atomic).
|
||||
template<typename ValueType, typename BinaryFunc>
|
||||
__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
|
||||
ValueType observed = *address, assumed, new_value;
|
||||
using CasType = typename AtomicCasType<ValueType>::type;
|
||||
static_assert(sizeof(ValueType) == sizeof(CasType),
|
||||
"ValueType and CasType must have the same size for calling atomicCAS.");
|
||||
auto address_as_cas_type = reinterpret_cast<CasType*>(address);
|
||||
do {
|
||||
// Record the value used to compute new value.
|
||||
assumed = observed;
|
||||
|
||||
// Compute expected new value.
|
||||
new_value = func(observed, val);
|
||||
|
||||
// Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS.
|
||||
// 4
|
||||
// 8
|
||||
auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
|
||||
auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
|
||||
|
||||
// Call atomicCAS as if the 2-byte type variables are all unsigned short int.
|
||||
// 4 unsigned int (or int)
|
||||
// 8 unsigned long long int
|
||||
auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);
|
||||
|
||||
// Cast the freshly observed value in memory back to the TwoByteType.
|
||||
observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
|
||||
|
||||
// Two cases:
|
||||
// 1. compare-and-swap success
|
||||
// a. `address` holds `new_value`
|
||||
// b. `observed` becomes the new value after the assignment.
|
||||
// Thus, the following `observed != new_value` is false,
|
||||
// and the loop terminates.
|
||||
// 2. compare-and-swap fails
|
||||
// a. `address` holds a value different from `observed`, thus,
|
||||
// the `new_value` is stale.
|
||||
// b. `observed` becomes the fresh value observed in `address`.
|
||||
// Thus, the following (observed != new_value) is true,
|
||||
// and the loop continues. In the next iteration, the
|
||||
// `new_value` is computed again using the fresh `observed`.
|
||||
} while (observed != assumed);
|
||||
}
|
||||
|
||||
struct AddFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct MulFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return b > a ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
struct MinFunc {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return b < a ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
// int8_t atomics: no native sub-word atomics exist for these operations,
// so they are routed through the 32-bit-word CAS emulation above.
__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, AddFunc());
}
__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
  atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}

// half atomics (mul/max/min) also use the 32-bit-word CAS emulation.
// (atomic add for half is provided elsewhere and is not defined here.)
__device__ __forceinline__ void atomic_mul(half* address, half value) {
  atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(half* address, half value) {
  atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(half* address, half value) {
  atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}

// float atomics: 4-byte values use atomicCAS directly via
// atomic_binary_func.
__device__ __forceinline__ void atomic_mul(float* address, float value) {
  atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(float* address, float value) {
  atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(float* address, float value) {
  atomic_binary_func(address, value, MinFunc());
}

// double atomics: 8-byte values use the unsigned long long atomicCAS via
// atomic_binary_func.
__device__ __forceinline__ void atomic_mul(double* address, double value) {
  atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(double* address, double value) {
  atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(double* address, double value) {
  atomic_binary_func(address, value, MinFunc());
}
|
||||
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1069,7 +1069,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax);
|
||||
|
|
@ -1290,6 +1290,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
|
||||
|
||||
// Opset 17
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
|
||||
|
|
@ -1302,7 +1303,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);
|
||||
|
||||
// Opset 19
|
||||
|
|
@ -2004,7 +2005,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
|
||||
|
|
@ -2225,6 +2226,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
|
||||
|
||||
// Opset 17
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
|
||||
|
|
@ -2237,7 +2239,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split)>,
|
||||
|
||||
// Opset 19
|
||||
|
|
|
|||
|
|
@ -302,5 +302,137 @@ TEST(Scatter, BoolInputWithAxis) {
|
|||
scatter_bool_with_axis_tests("ScatterElements", 11);
|
||||
}
|
||||
|
||||
// ScatterElements opset 18 with reduction="add" along axis 0.
// Every index is 1, so output row 1 accumulates data row 1 plus all four
// update rows; row 0 keeps its original values.
TEST(ScatterElements, AddReduction) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "add");

  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  test.AddInput<int64_t>("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  test.AddInput<float>("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f});
  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)});

  // NOTE(review): TensorRT/OpenVINO EPs are excluded — presumably they do
  // not support reduction on ScatterElements; confirm before re-enabling.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}

// Same reduction scattering along axis 1: every index is 1, so each row's
// column 1 accumulates that row's four update values.
TEST(ScatterElements, AddReductionAxis1) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 1);
  test.AddAttribute<std::string>("reduction", "add");

  // indices/updates have shape {2, 4}: four updates per row, all aimed at
  // column 1 of that row.
  test.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
  test.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
  test.AddInput<float>("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f});
  test.AddOutput<float>("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
|
||||
|
||||
// reduction="mul" along axis 0: all indices target row 1, so each element
// of output row 1 is multiplied by both update rows; row 0 is unchanged.
TEST(ScatterElements, MulReduction) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "mul");

  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  test.AddInput<float>("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f});
  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}

// reduction="mul" along axis 1: all indices target column 1, so each row's
// column 1 is multiplied by that row's four update values.
TEST(ScatterElements, MulReductionAxis1) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 1);
  test.AddAttribute<std::string>("reduction", "mul");

  // indices/updates have shape {2, 4}: four updates per row, all aimed at
  // column 1 of that row.
  test.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
  test.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
  test.AddInput<float>("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
  test.AddOutput<float>("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
|
||||
|
||||
// reduction="max" along axis 0 for MLFloat16 (exercises the half atomic_max
// CAS path). All indices target row 1: each output element in row 1 is
// max(data, update_row0, update_row1); row 0 is unchanged.
TEST(ScatterElements, MaxReduction_MLFloat16) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "max");

  test.AddInput<MLFloat16>("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}));
  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  test.AddInput<MLFloat16>("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f}));
  test.AddOutput<MLFloat16>("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}));

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}

// Same max reduction with float data (32-bit atomicCAS path).
TEST(ScatterElements, MaxReduction_Float) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "max");

  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  test.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}

// Same max reduction with double data (64-bit atomicCAS path).
TEST(ScatterElements, MaxReduction_Double) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "max");

  test.AddInput<double>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  test.AddInput<double>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  test.AddOutput<double>("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
|
||||
|
||||
// reduction="min" along axis 0 for MLFloat16 (half atomic_min CAS path).
// All indices target row 1: each output element in row 1 is
// min(data, update_row0, update_row1); row 0 is unchanged.
TEST(ScatterElements, MinReduction_MLFloat16) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "min");

  test.AddInput<MLFloat16>("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}));
  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  test.AddInput<MLFloat16>("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f}));
  test.AddOutput<MLFloat16>("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}));

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}

// Same min reduction with float data (32-bit atomicCAS path).
TEST(ScatterElements, MinReduction_Float) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "min");

  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f});
  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  test.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}

// Same min reduction with double data (64-bit atomicCAS path).
TEST(ScatterElements, MinReduction_Double) {
  OpTester test("ScatterElements", 18);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<std::string>("reduction", "min");

  test.AddInput<double>("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f});
  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
  test.AddInput<double>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
  test.AddOutput<double>("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f});

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue