mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
register softmaxinternal with rocm (#8289)
This commit is contained in:
parent
969eb545d1
commit
036eee5b66
1 changed files with 8 additions and 0 deletions
|
|
@ -57,6 +57,10 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
|
|||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossInternal);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossInternal);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossInternalGrad);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossInternalGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SoftmaxGrad);
|
||||
|
|
@ -201,6 +205,10 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossInternal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossInternalGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossInternalGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue