diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 2a0f0f48be..239db66cf4 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -29,8 +29,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Softm class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropy); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SparseSoftmaxCrossEntropyGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, float_int32_t, SoftmaxCrossEntropyLoss); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, float_int64_t, SoftmaxCrossEntropyLoss); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_int32_t, SoftmaxCrossEntropyLoss); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_int64_t, SoftmaxCrossEntropyLoss); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float_int32_t, SoftmaxCrossEntropyLoss); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float_int64_t, SoftmaxCrossEntropyLoss); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad); @@ -139,8 +141,10 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc index 11142aebdf..f182dff7d0 100644 --- a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc +++ b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc @@ -16,6 +16,20 @@ namespace onnxruntime { namespace contrib { +#define REGISTER_KERNEL_VERSIONED_TYPED(OpName, Domain, StartVer, EndVer, T1, T2) \ + ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX( \ + OpName, \ + Domain, \ + StartVer, \ + EndVer, \ + T1, \ + T2, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ + OpName); + #define REGISTER_KERNEL_TYPED(OpName, Domain, VER, T1, T2) \ ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \ OpName, \ @@ -29,8 +43,10 @@ namespace contrib { .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ OpName); -REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, float, int32_t) -REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, float, int64_t) +REGISTER_KERNEL_VERSIONED_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, 12, float, int32_t) +REGISTER_KERNEL_VERSIONED_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 12, 12, float, int64_t) +REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 13, float, int32_t) +REGISTER_KERNEL_TYPED(SoftmaxCrossEntropyLoss, kOnnxDomain, 13, float, int64_t) void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) { // N_D = N * D1 * D2...D*K