mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
SoftmaxCrossEntropyLoss OpSet13. (#5777)
Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
parent
b92fc66ea1
commit
2a87108431
2 changed files with 26 additions and 6 deletions
|
|
@ -29,8 +29,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Softm
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropy);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropyGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, float_int32_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, float_int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_int32_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float_int32_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float_int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad);
|
||||
|
|
@ -139,8 +141,10 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropy)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropyGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, float_int32_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, float_int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_int32_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float_int32_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float_int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad)>,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,20 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
#define REGISTER_KERNEL_VERSIONED_TYPED(OpName, Domain, StartVer, EndVer, T1, T2) \
|
||||
ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX( \
|
||||
OpName, \
|
||||
Domain, \
|
||||
StartVer, \
|
||||
EndVer, \
|
||||
T1, \
|
||||
T2, \
|
||||
kCpuExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T1>()) \
|
||||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<T2>()), \
|
||||
OpName<T1, T2>);
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(OpName, Domain, VER, T1, T2) \
|
||||
ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \
|
||||
OpName, \
|
||||
|
|
@ -29,8 +43,10 @@ namespace contrib {
|
|||
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<T2>()), \
|
||||
OpName<T1, T2>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, float, int32_t)
|
||||
REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, float, int64_t)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, 12, float, int32_t)
|
||||
REGISTER_KERNEL_VERSIONED_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, 12, float, int64_t)
|
||||
REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 13, float, int32_t)
|
||||
REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 13, float, int64_t)
|
||||
|
||||
void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) {
|
||||
// N_D = N * D1 * D2...D*K
|
||||
|
|
|
|||
Loading…
Reference in a new issue