From 0008e92b4e1d7f1068d8638dfe5decaa5bfa5c4e Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Mon, 20 Jul 2020 21:05:00 -0700 Subject: [PATCH] Internal ReduceSum op that accepts axes as input (#4522) * Initial change, to add ReduceSumTraining cpu op * cpu support * cuda support + more UTs * on comments + UT * no op support for {} axes with new attr - noop_with_empty_axes * on comments * fix build * on comments Co-authored-by: aishwarya bhandare Co-authored-by: Ethan Tao --- .../providers/cpu/reduction/reduction_ops.cc | 55 ++++++++ .../providers/cpu/reduction/reduction_ops.h | 18 +++ .../providers/cuda/reduction/reduction_ops.cc | 70 +++++++++- .../providers/cuda/reduction/reduction_ops.h | 13 ++ .../core/graph/training_op_defs.cc | 83 +++++++++++- .../cpu/reduction/reduction_ops_test.cc | 125 ++++++++++++++++++ .../training_ops/cpu/cpu_training_kernels.cc | 8 ++ .../cuda/cuda_training_kernels.cc | 8 ++ 8 files changed, 370 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index b37872b8ae..fe797a6beb 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -151,6 +151,24 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 11); REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 12); +namespace contrib { +#define REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ReduceSumTraining, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ReduceSumTraining); + +REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(float) +REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(double) +REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(int32_t) +REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(int64_t) +} // namespace 
contrib + // When all reduce axes are located at the tail of the dims, quite general cases, transpose and extra // copy could be skipped to improve performance. If required by check_no_transpose = true, then // the calling code will check if the data was transposed and act accordingly. @@ -650,6 +668,43 @@ Status ReduceSum::Compute(OpKernelContext* ctx) const { return Status::OK(); } +template +Status ReduceSumTraining::Compute(OpKernelContext* ctx) const { + FastAllocVector transposed_input_data(GetAllocator(*ctx)); + int64_t block_size; + int64_t blocks; + std::vector reduced_dims; + const Tensor* input = ctx->Input(0); + + //override the attribute value with the input value for reduction_axes + const Tensor* axes_tensor = ctx->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->template Data(); + std::vector axes(data, data + nDims); + if (axes.size() > 0) { + ORT_ENFORCE(noop_with_empty_axes_ == false, "Noop when axes is not empty is not allowed."); + } + + // empty axes and no-op + if (axes.empty() && noop_with_empty_axes_) { + auto* output = ctx->Output(0, input->Shape()); + memcpy(output->template MutableData(), input->template Data(), input->SizeInBytes()); // SizeInBytes() is already a byte count; scaling by sizeof(T) overruns both buffers + return Status::OK(); + } + + bool no_transpose = PrepareForReduce(input, transposed_input_data, block_size, blocks, axes, keepdims_, reduced_dims, true); + + auto* output = ctx->Output(0, reduced_dims); + + ReduceSumCore(input->template Data(), output->template MutableData(), + no_transpose, blocks, block_size, transposed_input_data, ctx->GetOperatorThreadPool()); + + return Status::OK(); +} + template Status ReduceSumSquare::Compute(OpKernelContext* ctx) const { FastAllocVector transposed_input_data(GetAllocator(*ctx)); diff --git 
a/onnxruntime/core/providers/cpu/reduction/reduction_ops.h b/onnxruntime/core/providers/cpu/reduction/reduction_ops.h index ea8b17fc41..a6dbff884a 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.h @@ -27,12 +27,15 @@ class ReduceKernelBase { ORT_ENFORCE(info.GetAttr("keepdims", &keepdims).IsOK()); } keepdims_ = (keepdims == 1); + int64_t noop_with_empty_axes = info.GetAttrOrDefault("noop_with_empty_axes", 0); + noop_with_empty_axes_ = (noop_with_empty_axes == 1); int64_t select_last_index = info.GetAttrOrDefault("select_last_index", 0); select_last_index_ = (select_last_index != 0); } std::vector axes_; bool keepdims_; + bool noop_with_empty_axes_; bool select_last_index_; }; @@ -129,6 +132,21 @@ class ReduceSum final : public ReduceKernel { const TensorShape* input_shape_override = nullptr); }; +template +class ReduceSumTraining final : public ReduceKernel { + public: + ReduceSumTraining(const OpKernelInfo& info) : ReduceKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; + + // For external calls requiring ReduceSumTraining implementation - will return the reduced output. + //`input_shape_override` overrides the shape of `input` for compute purposes. 
+ static Tensor Impl(const Tensor& input, const std::vector& reduce_axes, + AllocatorPtr allocator, concurrency::ThreadPool* tp, bool keep_dims, + const TensorShape* input_shape_override = nullptr); +}; + template class ReduceSumSquare final : public ReduceKernel { public: diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 5dad097899..592dc1ed51 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -60,6 +60,18 @@ namespace cuda { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); +#define REGISTER_MS_KERNEL_TYPED(name, T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + name, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .InputMemoryType(1) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + name); + // CUDA's reduction descriptor cudnnReduceTensorDescriptor_t is a pointer so // it's safer to wrap it with automatically memory deleter as CudnnReduceDescriptor. 
// An implicit caster from CudnnReduceDescriptor to cudnnReduceTensorDescriptor_t @@ -598,12 +610,34 @@ template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { const Tensor* X = ctx->Input(0); + const std::string& op_name = this->KernelDef().OpName(); + std::vector axes_values = axes_; + if (op_name == "ReduceSumTraining") { + //override the attribute value with the input value for reduction_axes + const Tensor* axes_tensor = ctx->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->template Data(); + std::vector axes(data, data + nDims); + axes_values = axes; + if (axes.size() > 0) { + ORT_ENFORCE(noop_with_empty_axes_ == false, "Noop when axes is not empty is not allowed."); + } + + // empty axes and no-op + if (axes.empty() && noop_with_empty_axes_) { + auto* Y = ctx->Output(0, X->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData(), X->template Data(), X->SizeInBytes(), cudaMemcpyDeviceToDevice)); // SizeInBytes() is already a byte count; scaling by sizeof(T) copies past both device buffers + return Status::OK(); + } + } + PrepareReduceMetadata prepare_reduce_metadata; ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, - axes_, + axes_values, prepare_reduce_metadata)); - Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); bool fast_reduction = fast_reduction_; if (fast_reduction) { @@ -612,7 +646,7 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe fast_reduction = false; } - return ReduceComputeCore(*cuda_ep_, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes_, + return ReduceComputeCore(*cuda_ep_, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes_values, calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction); } @@ -622,11 +656,34 @@ Status ReduceKernel::ComputeImpl( typedef typename ToCudaType::MappedType 
CudaT; const Tensor* X = ctx->Input(0); + + const std::string& op_name = this->KernelDef().OpName(); + std::vector axes_values = axes_; + if (op_name == "ReduceSumTraining") { + //override the attribute value with the input value for reduction_axes + const Tensor* axes_tensor = ctx->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->template Data(); + std::vector axes(data, data + nDims); + axes_values = axes; + if (axes.size() > 0) { + ORT_ENFORCE(noop_with_empty_axes_ == false, "Noop when axes is not empty is not allowed."); + } + + // empty axes and no-op + if (axes.empty() && noop_with_empty_axes_) { + auto* Y = ctx->Output(0, X->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData(), X->template Data(), X->SizeInBytes(), cudaMemcpyDeviceToDevice)); // SizeInBytes() is already a byte count; scaling by sizeof(int32_t) copies past both device buffers + return Status::OK(); + } + } + PrepareReduceMetadata prepare_reduce_metadata; ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, - axes_, + axes_values, prepare_reduce_metadata)); Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); @@ -925,6 +982,11 @@ REGISTER_KERNEL_TYPED_12(ReduceMin, int32_t) REGISTER_KERNEL_TYPED_12(ReduceMin, int8_t) REGISTER_KERNEL_TYPED_12(ReduceMin, uint8_t) +REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, MLFloat16) +REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, float) +REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, double) +REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, int32_t) + REGISTER_KERNEL_HFD(ReduceProd) REGISTER_KERNEL_HFD(ReduceSum) REGISTER_KERNEL_HFD(ReduceLogSum) diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h index c0e2b10bec..8cc8111f25 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h @@ -70,6 +70,7 @@ class 
ReduceKernel : public CudaKernel, public ReduceKernelBase::axes_; using ReduceKernelBase::keepdims_; + using ReduceKernelBase::noop_with_empty_axes_; bool calculate_log_; bool calculate_sqt_; @@ -174,6 +175,18 @@ class ReduceSum final : public ReduceKernel { } }; +template +class ReduceSumTraining final : public ReduceKernel { + public: + ReduceSumTraining(const OpKernelInfo& info) : ReduceKernel(info) { + fast_reduction_ = true; + } + + Status ComputeInternal(OpKernelContext* ctx) const override { + return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_ADD); + } +}; + template class ReduceLogSum final : public ReduceKernel { public: diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 7b4478cd38..2820b09b57 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3,6 +3,7 @@ #include "core/graph/op.h" #include "core/graph/contrib_ops/contrib_defs.h" +#include "core/providers/common.h" #include "orttraining/core/graph/training_op_defs.h" #include "onnx/defs/function.h" #include @@ -1026,7 +1027,7 @@ Example 4: "the case during training.", "T1", OpSchema::Optional) - .Input(4, "training_mode", + .Input(4, "training_mode", "If set to true then it indicates dropout is being used for " "training. It is an optional value hence unless specified explicitly, it is false. " "If it is false, ratio is ignored and the operation mimics inference mode where nothing " @@ -1058,6 +1059,76 @@ Example 4: } }); + ONNX_CONTRIB_OPERATOR_SCHEMA(ReduceSumTraining) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("ReduceSumTraining") + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Attr("noop_with_empty_axes", + "Perform reduction or not when axes is empty, default false mean perform reduction." 
+ " When axes is empty and this attribute is set to true, input tensor will not be reduced, " + "thus output tensor would be equivalent to input tensor.", + AttributeProto::INT, + static_cast(0)) + .AllowUncheckedAttributes() + .Input(0, "data", "An input tensor.", "T") + .Input(1, "axes", + "A list of integers, along which to reduce. The default is to reduce over " + "all the dimensions of the input tensor. Accepted range is [-r, r-1] where r = rank(data).", + "tensor(int64)") + .Output(0, "reduced", "Reduced output tensor.", "T") + .TypeConstraint( + "T", + OpSchema::numeric_types_for_math_reduction(), + "Constrain input and output types to high-precision numeric tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 1)) { + return; + } + + // skip if axes is not an initializer + auto axes_proto = ctx.getInputData(1); + if (axes_proto == nullptr) { + return; + } + + int64_t keep_dims = 1; + auto attr_proto = ctx.getAttribute("keepdims"); + if (attr_proto) { + keep_dims = attr_proto->i(); + } + auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + int64_t input_ndim = input_shape.dim_size(); + auto output_shape = + ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + std::vector axes_values = ParseData(axes_proto); + std::vector axes; + axes.reserve(axes_values.size()); + for (int64_t axis : axes_values) { + axes.push_back(HandleNegativeAxis(axis, input_ndim)); + } + + for (int i = 0; i < input_ndim; ++i) { + // axes empty means reduce all dim + if (!axes.empty() && + std::find(axes.begin(), axes.end(), i) == axes.end()) { + auto dim = output_shape->add_dim(); + dim->CopyFrom(input_shape.dim(i)); + } else { + if (keep_dims == 1) { + auto dim = output_shape->add_dim(); + dim->set_dim_value(1); + } + } + } + }); + ONNX_CONTRIB_OPERATOR_SCHEMA(DropoutGrad) .SetDomain(kMSDomain) .SinceVersion(1) @@ -1074,11 +1145,11 @@ Example 4: "T1", 
OpSchema::Optional) .Input(3, "training_mode", - "If set to true then it indicates dropout is being used for training. It is an optional value hence unless " - "specified explicitly, it is false. If it is false, ratio is ignored and the operation mimics inference mode where " - "nothing will be dropped from the input data and if mask is requested as output it will contain all ones.", - "T2", - OpSchema::Optional) + "If set to true then it indicates dropout is being used for training. It is an optional value hence unless " + "specified explicitly, it is false. If it is false, ratio is ignored and the operation mimics inference mode where " + "nothing will be dropped from the input data and if mask is requested as output it will contain all ones.", + "T2", + OpSchema::Optional) .Output(0, "dx", "Gradient of the input.", "T") .TypeConstraint( "T", diff --git a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc index eae779b38f..943dbe77a0 100644 --- a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc @@ -214,5 +214,130 @@ TEST(ReductionOpTest, ReduceAllL2Many) { #endif +TEST(ReductionOpTest, ReduceSumTraining_int32) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddInput("axes", {2}, {0, 2}, true /*is_initializer*/); + test.AddOutput("reduced", {1, 2, 1}, {33, 45}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceSumTraining_default_axes_keepdims) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + 
test.AddInput("axes", {0}, {}, true /*is_initializer*/); + test.AddOutput("reduced", {1, 1, 1}, {78.0f}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceSumTraining_axes_not_initializer) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddInput("axes", {0}, {}); + test.AddOutput("reduced", {1, 1, 1}, {78.0f}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceSumTraining_empty_axes_noop) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", (int64_t)1); + test.AddAttribute("noop_with_empty_axes", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddInput("axes", {0}, {}, true /*is_initializer*/); + test.AddOutput("reduced", {3, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceSumTraining_do_not_keepdims) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {1, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f}); + test.AddInput("axes", {1}, {1}, true /*is_initializer*/); + test.AddOutput("reduced", {1, 2}, {4.0f, 6.0f}); + test.Run(); +} + +TEST(ReductionOpTest, ReduceSumTraining_neg_axis) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {1, 2, 2}, + {1.0f, 2.0f, + 3.0f, 4.0f}); + test.AddInput("axes", {1}, {-2}, true /*is_initializer*/); + test.AddOutput("reduced", {1, 2}, {4.0f, 6.0f}); + test.Run(); +} + +#ifdef USE_CUDA +TEST(ReductionOpTest, ReduceSumTrainingHalfHalf) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", 
(int64_t)0); + + std::vector data = {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}; + std::vector data_half(12); + ConvertFloatToMLFloat16(data.data(), data_half.data(), 12); + + std::vector result = {36.0f, 42.0f}; + std::vector result_half(2); + ConvertFloatToMLFloat16(result.data(), result_half.data(), 2); + test.AddInput("data", {3, 2, 2}, data_half); + test.AddInput("axes", {2}, {0, 1}, true /*is_initializer*/); + test.AddOutput("reduced", {2}, result_half); + test.Run(); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 04f3bf8428..b974d35027 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -16,6 +16,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AdamO class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ReduceSumTraining); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ReduceSumTraining); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int64_t, ReduceSumTraining); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropy); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad); @@ -75,6 +79,10 @@ Status RegisterCpuTrainingKernels(KernelRegistry& 
kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 7ee7e79756..30b987072e 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -12,6 +12,10 @@ namespace cuda { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, View); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Group); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SGDOptimizer); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReduceSumTraining); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReduceSumTraining); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining); // Adam class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int64_t_float_float_float_float, AdamOptimizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_float_MLFloat16_float_float, AdamOptimizer); @@ -122,6 +126,10 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Adam BuildKernelCreateInfo,