mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Enable more unit tests for ROCM EP (#6776)
* enable more ops and unit tests for ROCM EP
This commit is contained in:
parent
f4acdb2ecd
commit
40fa40f3ce
6 changed files with 317 additions and 119 deletions
|
|
@ -674,7 +674,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Gather);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
|
||||
|
|
@ -767,7 +767,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, uint8_t, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, int8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, uint8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, CumSum);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 13, CumSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_float_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_float_int32_t, OneHot);
|
||||
|
|
@ -775,7 +775,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_MLFloat16_int32_t, OneHot);
|
||||
|
||||
// OpSet 12
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Clip);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool);
|
||||
|
|
@ -989,6 +989,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, If);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Loop);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Identity);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND);
|
||||
|
||||
template <>
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
|
|
@ -1252,9 +1265,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, Cast)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, float, Pad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, double, Pad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, float, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 4, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 5, 12, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, Shape)>,
|
||||
|
|
@ -1277,26 +1290,26 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, int32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, float, Slice)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Compress)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Compress)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, Upsample)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, Upsample)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, Upsample)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, uint8_t, Upsample)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, Split)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, ConstantOfShape)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int8_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int16_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int32_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int64_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint8_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint16_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint32_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint64_t, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, ConstantOfShape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int8_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int16_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int32_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int64_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint8_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint16_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint32_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint64_t, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less)>,
|
||||
|
|
@ -1307,7 +1320,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Less)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Less)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Less)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, EyeLike)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, EyeLike)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Scatter)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Where)>,
|
||||
|
|
@ -1328,7 +1341,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, AveragePool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 11, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 11, Dropout)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, double, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, MaxPool)>,
|
||||
|
|
@ -1338,12 +1351,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, MLFloat16, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, ReverseSequence)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, ReverseSequence)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, float, RoiAlign)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, double, RoiAlign)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, float, Slice)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, double, ThresholdedRelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu)>,
|
||||
|
|
@ -1361,9 +1374,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ArgMin)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm)>,
|
||||
|
|
@ -1372,7 +1385,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, If)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Loop)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, NonMaxSuppression)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Range)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Range)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL1)>,
|
||||
|
|
@ -1411,7 +1424,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Scan)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Slice)>,
|
||||
|
|
@ -1442,17 +1455,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, uint8_t, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, Clip)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Pad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Pad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, Clip)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, bool, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, CumSum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 13, CumSum)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_float_int64_t, OneHot)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_float_int32_t, OneHot)>,
|
||||
|
|
@ -1460,7 +1473,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_MLFloat16_int32_t, OneHot)>,
|
||||
|
||||
// OpSet 12
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Clip)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip)>,
|
||||
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool)>,
|
||||
|
|
@ -1674,6 +1687,19 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, If)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Loop)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, LRN)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Identity)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "div_grad.h"
|
||||
#include "div_grad_impl.h"
|
||||
#include "orttraining/training_ops/cuda/math/div_grad.h"
|
||||
#include "orttraining/training_ops/cuda/math/div_grad_impl.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
|
|
|
|||
244
orttraining/orttraining/training_ops/rocm/math/div_grad.cc
Normal file
244
orttraining/orttraining/training_ops/rocm/math/div_grad.cc
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/rocm/math/div_grad.h"
|
||||
#include "orttraining/training_ops/rocm/math/div_grad_impl.h"
|
||||
#include "core/providers/rocm/math/binary_elementwise_ops.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
#define DIVGRAD_REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
DivGrad, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
DivGrad<T>);
|
||||
|
||||
DIVGRAD_REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
DIVGRAD_REGISTER_KERNEL_TYPED(float)
|
||||
// DIVGRAD_REGISTER_KERNEL_TYPED(double)
|
||||
|
||||
std::vector<int64_t> prepended_dimension_1(const TensorShape& shape, size_t total_rank) {
|
||||
size_t input_rank = shape.NumDimensions();
|
||||
if (input_rank == total_rank)
|
||||
return shape.GetDims();
|
||||
|
||||
std::vector<int64_t> dims(total_rank, 1);
|
||||
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
|
||||
// for property 3 of Multidirectional Broadcasting, we need to prepended with a dimension of length 1.
|
||||
if (input_rank > 0)
|
||||
std::copy(shape.GetDims().begin(), shape.GetDims().end(), &dims[total_rank - input_rank]);
|
||||
return dims;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status DivGrad<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
|
||||
const Tensor* dy_tensor = context->Input<Tensor>(0);
|
||||
const Tensor* a_tensor = context->Input<Tensor>(1);
|
||||
const Tensor* b_tensor = context->Input<Tensor>(2);
|
||||
const TensorShape& a_shape = a_tensor->Shape();
|
||||
const TensorShape& b_shape = b_tensor->Shape();
|
||||
const TensorShape& dy_shape = dy_tensor->Shape();
|
||||
|
||||
// output shapes shall match its corresponding inputs
|
||||
Tensor* da_output_tensor = context->Output(0, a_shape);
|
||||
Tensor* db_output_tensor = context->Output(1, b_shape);
|
||||
if (!da_output_tensor && !db_output_tensor)
|
||||
return Status::OK();
|
||||
|
||||
BinaryElementwisePreparation prepare;
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(a_tensor, b_tensor,
|
||||
// TODO: BinaryElementwiseBroadcastPrepare shall take dy_tensor as const Tensor*.
|
||||
const_cast<Tensor*>(dy_tensor), &prepare));
|
||||
const HipT* prepare_a_data = reinterpret_cast<const HipT*>(prepare.lhs_tensor->template Data<T>());
|
||||
const HipT* prepare_b_data = reinterpret_cast<const HipT*>(prepare.rhs_tensor->template Data<T>());
|
||||
const HipT* prepare_dy_data = reinterpret_cast<const HipT*>(prepare.output_tensor->template Data<T>());
|
||||
T* da_data = da_output_tensor ? da_output_tensor->template MutableData<T>() : nullptr;
|
||||
T* db_data = db_output_tensor ? db_output_tensor->template MutableData<T>() : nullptr;
|
||||
|
||||
switch (prepare.output_rank_or_simple_broadcast) {
|
||||
case static_cast<int32_t>(SimpleBroadcast::NoBroadcast):
|
||||
ImplDivGradSimple<HipT>(
|
||||
Stream(),
|
||||
SimpleBroadcast::NoBroadcast,
|
||||
prepare_a_data,
|
||||
prepare_b_data,
|
||||
prepare_dy_data,
|
||||
dy_shape.Size(),
|
||||
reinterpret_cast<HipT*>(da_data),
|
||||
reinterpret_cast<HipT*>(db_data));
|
||||
break;
|
||||
case static_cast<int32_t>(SimpleBroadcast::LeftScalar): {
|
||||
T* temp_da_data = nullptr;
|
||||
IAllocatorUniquePtr<T> temp_da_allocator;
|
||||
if (da_output_tensor) {
|
||||
temp_da_allocator = GetScratchBuffer<T>(dy_shape.Size());
|
||||
temp_da_data = temp_da_allocator.get();
|
||||
}
|
||||
|
||||
ImplDivGradSimple<HipT>(
|
||||
Stream(),
|
||||
SimpleBroadcast::LeftScalar,
|
||||
prepare_a_data,
|
||||
prepare_b_data,
|
||||
prepare_dy_data,
|
||||
dy_shape.Size(),
|
||||
reinterpret_cast<HipT*>(temp_da_data),
|
||||
reinterpret_cast<HipT*>(db_data));
|
||||
|
||||
if (da_output_tensor) {
|
||||
std::vector<int64_t> a_output_dims = prepended_dimension_1(a_shape, dy_shape.NumDimensions());
|
||||
ReduceKernelShared<T, T, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
|
||||
temp_da_data,
|
||||
dy_shape,
|
||||
da_data,
|
||||
TensorShape({}),
|
||||
MIOPEN_REDUCE_TENSOR_ADD,
|
||||
a_output_dims);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case static_cast<int32_t>(SimpleBroadcast::RightScalar): {
|
||||
T* temp_db_data = nullptr;
|
||||
IAllocatorUniquePtr<T> temp_db_allocator;
|
||||
if (db_output_tensor) {
|
||||
temp_db_allocator = GetScratchBuffer<T>(dy_shape.Size());
|
||||
temp_db_data = temp_db_allocator.get();
|
||||
}
|
||||
ImplDivGradSimple<HipT>(
|
||||
Stream(),
|
||||
SimpleBroadcast::RightScalar,
|
||||
prepare_a_data,
|
||||
prepare_b_data,
|
||||
prepare_dy_data,
|
||||
dy_shape.Size(),
|
||||
reinterpret_cast<HipT*>(da_data),
|
||||
reinterpret_cast<HipT*>(temp_db_data));
|
||||
|
||||
if (db_output_tensor) {
|
||||
std::vector<int64_t> b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions());
|
||||
ReduceKernelShared<T, T, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
|
||||
temp_db_data,
|
||||
dy_shape,
|
||||
db_data,
|
||||
TensorShape({}),
|
||||
MIOPEN_REDUCE_TENSOR_ADD,
|
||||
b_output_dims);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatch1):
|
||||
case static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatchN): {
|
||||
T* temp_db_data = nullptr;
|
||||
IAllocatorUniquePtr<T> temp_db_allocator;
|
||||
if (db_output_tensor) {
|
||||
temp_db_allocator = GetScratchBuffer<T>(dy_shape.Size());
|
||||
temp_db_data = temp_db_allocator.get();
|
||||
}
|
||||
if (prepare.output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatch1)) {
|
||||
// lhs(1,C,H) and rhs (C,1)
|
||||
ImplDivGradRhsPerChannelBatch1<HipT>(
|
||||
Stream(),
|
||||
prepare_a_data,
|
||||
prepare_b_data,
|
||||
prepare_dy_data,
|
||||
dy_shape.Size(),
|
||||
prepare.fdm_H,
|
||||
reinterpret_cast<HipT*>(da_data),
|
||||
reinterpret_cast<HipT*>(temp_db_data));
|
||||
} else {
|
||||
// lhs(N,C,H) and rhs (C,1)
|
||||
ImplDivGradRhsPerChannelBatchN<HipT>(
|
||||
Stream(),
|
||||
prepare_a_data,
|
||||
prepare_b_data,
|
||||
prepare_dy_data,
|
||||
dy_shape.Size(),
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<HipT*>(da_data),
|
||||
reinterpret_cast<HipT*>(temp_db_data));
|
||||
}
|
||||
|
||||
if (db_output_tensor) {
|
||||
std::vector<int64_t> b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions());
|
||||
ReduceKernelShared<T, T, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
|
||||
temp_db_data,
|
||||
dy_shape,
|
||||
db_data,
|
||||
b_shape,
|
||||
MIOPEN_REDUCE_TENSOR_ADD,
|
||||
b_output_dims);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
bool need_reduce_da = da_output_tensor && a_shape.Size() != dy_shape.Size();
|
||||
bool need_reduce_db = db_output_tensor && b_shape.Size() != dy_shape.Size();
|
||||
IAllocatorUniquePtr<T> temp_da_allocator, temp_db_allocator;
|
||||
T* da_data_ref = nullptr;
|
||||
if (da_output_tensor)
|
||||
if (need_reduce_da) {
|
||||
temp_da_allocator = GetScratchBuffer<T>(dy_shape.Size());
|
||||
da_data_ref = temp_da_allocator.get();
|
||||
} else {
|
||||
da_data_ref = da_data;
|
||||
}
|
||||
T* db_data_ref = nullptr;
|
||||
if (db_output_tensor)
|
||||
if (need_reduce_db) {
|
||||
temp_db_allocator = GetScratchBuffer<T>(dy_shape.Size());
|
||||
db_data_ref = temp_db_allocator.get();
|
||||
} else {
|
||||
db_data_ref = db_data;
|
||||
}
|
||||
|
||||
ImplDivGrad<HipT>(
|
||||
Stream(),
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
&prepare.lhs_padded_strides,
|
||||
prepare_a_data,
|
||||
&prepare.rhs_padded_strides,
|
||||
prepare_b_data,
|
||||
prepare_dy_data,
|
||||
dy_shape.Size(),
|
||||
&prepare.fdm_output_strides,
|
||||
reinterpret_cast<HipT*>(da_data_ref),
|
||||
reinterpret_cast<HipT*>(db_data_ref));
|
||||
|
||||
if (need_reduce_da) {
|
||||
std::vector<int64_t> a_output_dims = prepended_dimension_1(a_shape, dy_shape.NumDimensions());
|
||||
ReduceKernelShared<T, T, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
|
||||
da_data_ref,
|
||||
dy_shape,
|
||||
da_data,
|
||||
a_shape,
|
||||
MIOPEN_REDUCE_TENSOR_ADD,
|
||||
a_output_dims);
|
||||
}
|
||||
|
||||
if (need_reduce_db) {
|
||||
std::vector<int64_t> b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions());
|
||||
ReduceKernelShared<T, T, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
|
||||
db_data_ref,
|
||||
dy_shape,
|
||||
db_data,
|
||||
b_shape,
|
||||
MIOPEN_REDUCE_TENSOR_ADD,
|
||||
b_output_dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -197,9 +197,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DivGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DivGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, DivGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DivGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DivGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, GeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GeluGrad)>,
|
||||
|
|
@ -234,7 +234,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double_double, InvertibleLayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_float, InvertibleLayerNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SliceGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Scale)>,
|
||||
|
|
|
|||
|
|
@ -79,20 +79,10 @@ provider_excluded_files = [
|
|||
'controlflow/scan.cc',
|
||||
'controlflow/scan.h',
|
||||
'cu_inc/common.cuh',
|
||||
'generator/constant_of_shape.cc',
|
||||
'generator/constant_of_shape.h',
|
||||
'generator/range.cc',
|
||||
'generator/range.h',
|
||||
'generator/range_impl.cu',
|
||||
'generator/range_impl.h',
|
||||
'math/einsum_utils/einsum_auxiliary_ops.cc',
|
||||
'math/einsum_utils/einsum_auxiliary_ops.h',
|
||||
'math/einsum_utils/einsum_auxiliary_ops_diagonal.cu',
|
||||
'math/einsum_utils/einsum_auxiliary_ops_diagonal.h',
|
||||
'math/cumsum.cc',
|
||||
'math/cumsum.h',
|
||||
'math/cumsum_impl.cu',
|
||||
'math/cumsum_impl.h',
|
||||
'math/einsum.cc',
|
||||
'math/einsum.h',
|
||||
'math/gemm.cc',
|
||||
|
|
@ -123,10 +113,6 @@ provider_excluded_files = [
|
|||
'nn/max_pool_with_index.h',
|
||||
'nn/pool.cc',
|
||||
'nn/pool.h',
|
||||
'nn/shrink.cc',
|
||||
'nn/shrink.h',
|
||||
'nn/shrink_impl.cu',
|
||||
'nn/shrink_impl.h',
|
||||
'object_detection/non_max_suppression.cc',
|
||||
'object_detection/non_max_suppression.h',
|
||||
'object_detection/non_max_suppression_impl.cu',
|
||||
|
|
@ -151,25 +137,7 @@ provider_excluded_files = [
|
|||
'shared_inc/fast_divmod.h',
|
||||
'shared_inc/fpgeneric.h',
|
||||
'shared_inc/integer_gemm.h',
|
||||
'tensor/compress.cc',
|
||||
'tensor/compress.h',
|
||||
'tensor/compress_impl.cu',
|
||||
'tensor/compress_impl.h',
|
||||
'tensor/eye_like.cc',
|
||||
'tensor/eye_like.h',
|
||||
'tensor/eye_like_impl.cu',
|
||||
'tensor/eye_like_impl.h',
|
||||
'tensor/flatten.cc',
|
||||
'tensor/flatten.h',
|
||||
'tensor/gather_elements.cc',
|
||||
'tensor/gather_elements.h',
|
||||
'tensor/gather_elements_impl.cu',
|
||||
'tensor/gather_elements_impl.h',
|
||||
'tensor/gather_nd_impl.cu',
|
||||
'tensor/pad.cc',
|
||||
'tensor/pad.h',
|
||||
'tensor/pad_impl.cu',
|
||||
'tensor/pad_impl.h',
|
||||
'tensor/quantize_linear.cc',
|
||||
'tensor/quantize_linear.cu',
|
||||
'tensor/quantize_linear.cuh',
|
||||
|
|
@ -178,10 +146,6 @@ provider_excluded_files = [
|
|||
'tensor/resize.h',
|
||||
'tensor/resize_impl.cu',
|
||||
'tensor/resize_impl.h',
|
||||
'tensor/reverse_sequence.cc',
|
||||
'tensor/reverse_sequence.h',
|
||||
'tensor/reverse_sequence_impl.cu',
|
||||
'tensor/reverse_sequence_impl.h',
|
||||
'tensor/transpose.cc',
|
||||
'tensor/transpose.h',
|
||||
'tensor/upsample.cc',
|
||||
|
|
@ -234,9 +198,6 @@ training_ops_excluded_files = [
|
|||
'controlflow/wait.cc',
|
||||
'controlflow/wait.h',
|
||||
'math/div_grad.cc',
|
||||
'math/div_grad.h',
|
||||
'math/div_grad_impl.cu',
|
||||
'math/div_grad_impl.h',
|
||||
'math/softmax_grad_impl.cu',
|
||||
'math/softmax_grad.cc',
|
||||
'nn/batch_norm_grad.cc',
|
||||
|
|
@ -246,8 +207,6 @@ training_ops_excluded_files = [
|
|||
'optimizer/lamb.cc',
|
||||
'reduction/reduction_all.cc',
|
||||
'reduction/reduction_ops.cc',
|
||||
'tensor/gather_elements_grad.cc',
|
||||
'tensor/gather_elements_grad.h',
|
||||
'tensor/gather_grad.cc',
|
||||
'tensor/gather_grad_impl.cu',
|
||||
'tensor/gather_grad_impl.h',
|
||||
|
|
|
|||
|
|
@ -12,13 +12,9 @@ OptimizerTest.LambOptimizerTestBaselineMixPrecision32_16
|
|||
OptimizerTest.LambOptimizerTestScalarMixPrecision32_16
|
||||
OptimizerTest.LambOptimizerTestScalarMixPrecision32_16_NoDefaultMaxNormClipping
|
||||
OptimizerTest.LambOptimizerTestLarge
|
||||
CudaKernelTest.SoftmaxCrossEntropy_TinySizeTensor
|
||||
CudaKernelTest.SoftmaxCrossEntropy_SmallSizeTensor
|
||||
CudaKernelTest.SoftmaxCrossEntropy_MediumSizeTensor
|
||||
CudaKernelTest.SoftmaxCrossEntropy_LargeSizeTensor
|
||||
CudaKernelTest.SparseSoftmaxCrossEntropy_LargeSizeTensor
|
||||
CudaKernelTest.NegativeLogLikelihoodLoss_TinySizeTensor
|
||||
CudaKernelTest.NegativeLogLikelihoodLoss_SmallSizeTensor
|
||||
CudaKernelTest.NegativeLogLikelihoodLoss_TinySizeTensor
|
||||
CudaKernelTest.NegativeLogLikelihoodLoss_SmallSizeTensor
|
||||
CudaKernelTest.NegativeLogLikelihoodLoss_MediumSizeTensor
|
||||
ReductionOpTest.ReductionVariationTest
|
||||
ReductionOpTest.ReduceL1_default_axes_keepdims
|
||||
|
|
@ -76,39 +72,12 @@ ReductionOpTest.ReduceInfLogSumExp_double
|
|||
GatherOpTest.Gather_invalid_index_cpu
|
||||
Scatter.InvalidIndex
|
||||
LogSoftmaxOperator.LargeNumber
|
||||
GatherElementsGrad.WithoutAxis
|
||||
GatherElementsGrad.WithAxis
|
||||
GatherElementsGrad.ThreeDimsWithAxis_0
|
||||
GatherElementsGrad.ThreeDimsWithAxis_2
|
||||
GatherElementsGrad.NegativeAxis
|
||||
GatherElementsGrad.IndicesUpdatesDontMatch
|
||||
GatherElementsGrad.ValidAxis
|
||||
GatherElementsGrad.ValidNegativeIndex
|
||||
GatherElementsGrad.SameUpdateWithoutAxis
|
||||
GatherElementsGrad.SameUpdateWithAxis
|
||||
GatherElementsGrad.SameUpdateWithNegativeAxis
|
||||
GatherElementsGrad.SameUpdateWithoutAxisMLFloat16
|
||||
MathOpTest.Max_8_2inputbroadcast
|
||||
MathOpTest.Less_broadcastBA
|
||||
MathOpTest.Less_multidiretional_broadcastAB
|
||||
MathOpTest.Less_multidiretional_broadcastBA
|
||||
MathOpTest.Greater_broadcastBA
|
||||
MathOpTest.Greater_multidiretional_broadcastAB
|
||||
MathOpTest.Greater_multidiretional_broadcastBA
|
||||
MathOpTest.Pow_int64_float
|
||||
MathOpTest.Pow_int32_float
|
||||
MathOpTest.Pow_int64_double
|
||||
GradientCheckerTest.AddGrad
|
||||
GradientCheckerTest.SubGrad
|
||||
GradientCheckerTest.MulGrad
|
||||
GradientCheckerTest.MatMulGrad
|
||||
GradientCheckerTest.ReduceMeanGrad
|
||||
GradientCheckerTest.ReduceL2Grad
|
||||
GradientCheckerTest.SoftmaxCrossEntropyGrad
|
||||
GradientCheckerTest.ExpandGrad
|
||||
GradientCheckerTest.DivGrad
|
||||
GradientCheckerTest.GemmGrad
|
||||
GradientCheckerTest.SplitGrad
|
||||
GradientCheckerTest.SqueezeGrad
|
||||
GradientCheckerTest.UnsqueezeGrad
|
||||
GradientCheckerTest.ClipGrad
|
||||
|
|
|
|||
Loading…
Reference in a new issue