Internal ReduceSum op that accepts axes as input (#4522)

* Initial change, to add ReduceSumTraining cpu op

* cpu support

* cuda support + more UTs

* on comments + UT

* no op support for {} axes with new attr - noop_with_empty_axes

* on comments

* fix build

* on comments

Co-authored-by: aishwarya bhandare <aibhanda@microsoft.com>
Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
ytaous 2020-07-20 21:05:00 -07:00 committed by GitHub
parent e92e0860c8
commit 0008e92b4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 370 additions and 10 deletions

View file

@ -151,6 +151,24 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10);
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 11);
REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 12);
namespace contrib {
// Registers the CPU kernel ReduceSumTraining<T> for the ReduceSumTraining op
// (kMSDomain, version 1) on kCpuExecutionProvider. The schema type constraint "T"
// is bound to tensors whose element type is T.
#define REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ReduceSumTraining, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ReduceSumTraining<T>);
// Element types for which a CPU ReduceSumTraining kernel is registered.
REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(float)
REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(double)
REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(int32_t)
REGISTER_REDUCESUMTRAINING_KERNEL_TYPED(int64_t)
} // namespace contrib
// When all reduce axes are located at the tail of the dims (a fairly common case), the transpose and
// extra copy can be skipped to improve performance. If check_no_transpose is true, the calling code
// is expected to check whether the data was transposed and act accordingly.
@ -650,6 +668,43 @@ Status ReduceSum<T>::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
// Sums input 0 over the axes given at runtime by input 1 (a 1-D int64 tensor)
// instead of by an `axes` attribute.
//
// Behavior:
//  - The axes input must be present and one-dimensional (enforced).
//  - A non-empty axes input combined with noop_with_empty_axes_ == true is rejected.
//  - An empty axes input with noop_with_empty_axes_ == true copies the input to the
//    output unchanged (same shape).
//  - Otherwise the reduction is delegated to PrepareForReduce / ReduceSumCore.
//
// Returns Status::OK() on success; precondition violations are reported through ORT_ENFORCE.
template <typename T>
Status ReduceSumTraining<T>::Compute(OpKernelContext* ctx) const {
  FastAllocVector<T> transposed_input_data(GetAllocator<T>(*ctx));
  int64_t block_size;
  int64_t blocks;
  std::vector<int64_t> reduced_dims;
  const Tensor* input = ctx->Input<Tensor>(0);

  // Override the attribute value with the input value for reduction axes.
  const Tensor* axes_tensor = ctx->Input<Tensor>(1);
  ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
              "An axes tensor must be a vector tensor.");
  auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
  const auto* data = axes_tensor->template Data<int64_t>();
  std::vector<int64_t> axes(data, data + nDims);

  if (axes.size() > 0) {
    ORT_ENFORCE(noop_with_empty_axes_ == false, "Noop when axes is not empty is not allowed.");
  }

  // Empty axes and no-op: the output is a plain copy of the input.
  if (axes.empty() && noop_with_empty_axes_) {
    auto* output = ctx->Output(0, input->Shape());
    // SizeInBytes() is already a byte count; scaling it by sizeof(T) again would
    // read and write past the end of both buffers.
    memcpy(output->template MutableData<T>(), input->template Data<T>(), input->SizeInBytes());
    return Status::OK();
  }

  bool no_transpose = PrepareForReduce<T>(input, transposed_input_data, block_size, blocks, axes, keepdims_, reduced_dims, true);
  auto* output = ctx->Output(0, reduced_dims);
  ReduceSumCore(input->template Data<T>(), output->template MutableData<T>(),
                no_transpose, blocks, block_size, transposed_input_data, ctx->GetOperatorThreadPool());
  return Status::OK();
}
template <typename T>
Status ReduceSumSquare<T>::Compute(OpKernelContext* ctx) const {
FastAllocVector<T> transposed_input_data(GetAllocator<T>(*ctx));

View file

@ -27,12 +27,15 @@ class ReduceKernelBase {
ORT_ENFORCE(info.GetAttr("keepdims", &keepdims).IsOK());
}
keepdims_ = (keepdims == 1);
int64_t noop_with_empty_axes = info.GetAttrOrDefault<int64_t>("noop_with_empty_axes", 0);
noop_with_empty_axes_ = (noop_with_empty_axes == 1);
int64_t select_last_index = info.GetAttrOrDefault<int64_t>("select_last_index", 0);
select_last_index_ = (select_last_index != 0);
}
std::vector<int64_t> axes_;
bool keepdims_;
bool noop_with_empty_axes_;
bool select_last_index_;
};
@ -129,6 +132,21 @@ class ReduceSum final : public ReduceKernel<true> {
const TensorShape* input_shape_override = nullptr);
};
// CPU kernel declaration for the ReduceSumTraining op, templated on element type T.
template <typename T>
class ReduceSumTraining final : public ReduceKernel<true> {
 public:
  ReduceSumTraining(const OpKernelInfo& info) : ReduceKernel<true>(info) {}

  // Kernel entry point; defined out of line.
  Status Compute(OpKernelContext* context) const override;

  // Static entry point for external callers that need the ReduceSumTraining
  // implementation; returns the reduced output tensor.
  // `input_shape_override`, when supplied, overrides the shape of `input` for
  // compute purposes.
  static Tensor Impl(const Tensor& input,
                     const std::vector<int64_t>& reduce_axes,
                     AllocatorPtr allocator,
                     concurrency::ThreadPool* tp,
                     bool keep_dims,
                     const TensorShape* input_shape_override = nullptr);
};
template <typename T>
class ReduceSumSquare final : public ReduceKernel<true> {
public:

View file

@ -60,6 +60,18 @@ namespace cuda {
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// Registers the CUDA kernel name<T> for op `name` in kMSDomain, version 1, on
// kCudaExecutionProvider. Input 1 is declared with OrtMemTypeCPUInput memory type
// (presumably so the axes tensor can be read on the host - confirm against callers).
// The schema type constraint "T" is bound to tensors whose element type is T.
#define REGISTER_MS_KERNEL_TYPED(name, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.InputMemoryType<OrtMemTypeCPUInput>(1) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// CUDA's reduction descriptor cudnnReduceTensorDescriptor_t is a pointer, so
// it's safer to wrap it in CudnnReduceDescriptor, which deletes it automatically.
// An implicit caster from CudnnReduceDescriptor to cudnnReduceTensorDescriptor_t
@ -598,12 +610,34 @@ template <typename T, cudnnReduceTensorIndices_t ReduceTensorIndices>
Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const {
const Tensor* X = ctx->Input<Tensor>(0);
const std::string& op_name = this->KernelDef().OpName();
std::vector<int64_t> axes_values = axes_;
if (op_name == "ReduceSumTraining") {
  // ReduceSumTraining takes its reduction axes from input 1 (a 1-D int64 tensor)
  // rather than from the `axes` attribute, so override the attribute value here.
  const Tensor* axes_tensor = ctx->Input<Tensor>(1);
  ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor.");
  auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
  const auto* data = axes_tensor->template Data<int64_t>();
  // Fill axes_values directly instead of building a temporary vector and copying it.
  axes_values.assign(data, data + nDims);
  if (axes_values.size() > 0) {
    ORT_ENFORCE(noop_with_empty_axes_ == false, "Noop when axes is not empty is not allowed.");
  }
  // Empty axes and no-op: the output is a plain device-to-device copy of the input.
  if (axes_values.empty() && noop_with_empty_axes_) {
    auto* Y = ctx->Output(0, X->Shape());
    // SizeInBytes() is already a byte count; scaling it by sizeof(T) again would
    // copy past the end of both device buffers.
    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData<T>(), X->template Data<T>(), X->SizeInBytes(), cudaMemcpyDeviceToDevice));
    return Status::OK();
  }
}
PrepareReduceMetadata prepare_reduce_metadata;
ORT_RETURN_IF_ERROR(PrepareForReduce(X,
keepdims_,
axes_,
axes_values,
prepare_reduce_metadata));
Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);
bool fast_reduction = fast_reduction_;
if (fast_reduction) {
@ -612,7 +646,7 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnRe
fast_reduction = false;
}
return ReduceComputeCore<T, ReduceTensorIndices>(*cuda_ep_, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes_,
return ReduceComputeCore<T, ReduceTensorIndices>(*cuda_ep_, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes_values,
calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction);
}
@ -622,11 +656,34 @@ Status ReduceKernel<true>::ComputeImpl<int32_t, CUDNN_REDUCE_TENSOR_NO_INDICES>(
typedef typename ToCudaType<int32_t>::MappedType CudaT;
const Tensor* X = ctx->Input<Tensor>(0);
const std::string& op_name = this->KernelDef().OpName();
std::vector<int64_t> axes_values = axes_;
if (op_name == "ReduceSumTraining") {
  // ReduceSumTraining takes its reduction axes from input 1 (a 1-D int64 tensor)
  // rather than from the `axes` attribute, so override the attribute value here.
  const Tensor* axes_tensor = ctx->Input<Tensor>(1);
  // Same null check as the generic ComputeImpl path; without it a missing axes
  // input is dereferenced on the next line.
  ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor.");
  auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
  const auto* data = axes_tensor->template Data<int64_t>();
  // Fill axes_values directly instead of building a temporary vector and copying it.
  axes_values.assign(data, data + nDims);
  if (axes_values.size() > 0) {
    ORT_ENFORCE(noop_with_empty_axes_ == false, "Noop when axes is not empty is not allowed.");
  }
  // Empty axes and no-op: the output is a plain device-to-device copy of the input.
  if (axes_values.empty() && noop_with_empty_axes_) {
    auto* Y = ctx->Output(0, X->Shape());
    // SizeInBytes() is already a byte count; scaling it by sizeof(int32_t) again
    // would copy past the end of both device buffers.
    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData<int32_t>(), X->template Data<int32_t>(), X->SizeInBytes(), cudaMemcpyDeviceToDevice));
    return Status::OK();
  }
}
PrepareReduceMetadata prepare_reduce_metadata;
ORT_RETURN_IF_ERROR(PrepareForReduce(X,
keepdims_,
axes_,
axes_values,
prepare_reduce_metadata));
Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);
@ -925,6 +982,11 @@ REGISTER_KERNEL_TYPED_12(ReduceMin, int32_t)
REGISTER_KERNEL_TYPED_12(ReduceMin, int8_t)
REGISTER_KERNEL_TYPED_12(ReduceMin, uint8_t)
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, MLFloat16)
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, float)
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, double)
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, int32_t)
REGISTER_KERNEL_HFD(ReduceProd)
REGISTER_KERNEL_HFD(ReduceSum)
REGISTER_KERNEL_HFD(ReduceLogSum)

View file

@ -70,6 +70,7 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
using ReduceKernelBase<allow_multi_axes>::axes_;
using ReduceKernelBase<allow_multi_axes>::keepdims_;
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_;
bool calculate_log_;
bool calculate_sqt_;
@ -174,6 +175,18 @@ class ReduceSum final : public ReduceKernel<true> {
}
};
// CUDA kernel for the ReduceSumTraining op, templated on element type T.
template <typename T>
class ReduceSumTraining final : public ReduceKernel<true> {
 public:
  ReduceSumTraining(const OpKernelInfo& info) : ReduceKernel<true>(info) {
    // This kernel always requests the fast reduction path.
    fast_reduction_ = true;
  }

  // Forwards to the shared reduce implementation with the cuDNN ADD reduction op.
  Status ComputeInternal(OpKernelContext* ctx) const override {
    return ComputeImpl<T>(ctx, CUDNN_REDUCE_TENSOR_ADD);
  }
};
template <typename T>
class ReduceLogSum final : public ReduceKernel<true> {
public:

View file

@ -3,6 +3,7 @@
#include "core/graph/op.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/providers/common.h"
#include "orttraining/core/graph/training_op_defs.h"
#include "onnx/defs/function.h"
#include <math.h>
@ -1026,7 +1027,7 @@ Example 4:
"the case during training.",
"T1",
OpSchema::Optional)
.Input(4, "training_mode",
.Input(4, "training_mode",
"If set to true then it indicates dropout is being used for "
"training. It is an optional value hence unless specified explicitly, it is false. "
"If it is false, ratio is ignored and the operation mimics inference mode where nothing "
@ -1058,6 +1059,76 @@ Example 4:
}
});
// Schema for the internal ReduceSumTraining op: a ReduceSum whose reduction axes are
// supplied as input 1 (tensor(int64)) instead of as an attribute.
//
// Shape inference runs only when input 0 has a shape and the axes input has constant
// data available (i.e. it is an initializer); otherwise only the element type is
// propagated to the output.
ONNX_CONTRIB_OPERATOR_SCHEMA(ReduceSumTraining)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
    .SetDoc("ReduceSumTraining")
    .Attr("keepdims",
          "Keep the reduced dimension or not, default 1 mean keep reduced dimension.",
          AttributeProto::INT,
          static_cast<int64_t>(1))
    .Attr("noop_with_empty_axes",
          "Perform reduction or not when axes is empty, default false mean perform reduction. "
          "when axes is empty and this attribute is set to true, input tensor will not be reduced, "
          "thus output tensor would be equivalent to input tensor.",
          AttributeProto::INT,
          static_cast<int64_t>(0))
    .AllowUncheckedAttributes()
    .Input(0, "data", "An input tensor.", "T")
    .Input(1, "axes",
           "A list of integers, along which to reduce. The default is to reduce over "
           "all the dimensions of the input tensor. Accepted range is [-r, r-1] where r = rank(data).",
           "tensor(int64)")
    .Output(0, "reduced", "Reduced output tensor.", "T")
    .TypeConstraint(
        "T",
        OpSchema::numeric_types_for_math_reduction(),
        "Constrain input and output types to high-precision numeric tensors.")
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      if (!hasNInputShapes(ctx, 1)) {
        return;
      }

      // skip if axes is not an initializer
      auto axes_proto = ctx.getInputData(1);
      if (axes_proto == nullptr) {
        return;
      }

      int64_t keep_dims = 1;
      auto attr_proto = ctx.getAttribute("keepdims");
      if (attr_proto) {
        keep_dims = attr_proto->i();
      }

      int64_t noop_with_empty_axes = 0;
      auto noop_attr_proto = ctx.getAttribute("noop_with_empty_axes");
      if (noop_attr_proto) {
        noop_with_empty_axes = noop_attr_proto->i();
      }

      auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
      int64_t input_ndim = input_shape.dim_size();
      auto output_shape =
          ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
      std::vector<int64_t> axes_values = ParseData<int64_t>(axes_proto);

      // Empty axes with noop_with_empty_axes set: the kernels return the input
      // unchanged, so the inferred output shape must be the input shape rather than
      // the fully reduced shape.
      if (axes_values.empty() && noop_with_empty_axes == 1) {
        output_shape->CopyFrom(input_shape);
        return;
      }

      std::vector<int64_t> axes;
      axes.reserve(axes_values.size());
      for (int64_t axis : axes_values) {
        axes.push_back(HandleNegativeAxis(axis, input_ndim));
      }

      for (int i = 0; i < input_ndim; ++i) {
        // axes empty means reduce all dim
        if (!axes.empty() &&
            std::find(axes.begin(), axes.end(), i) == axes.end()) {
          // Dimension is not reduced: carry it over as-is.
          auto dim = output_shape->add_dim();
          dim->CopyFrom(input_shape.dim(i));
        } else {
          // Dimension is reduced: keep it as size 1 only when keepdims is set.
          if (keep_dims == 1) {
            auto dim = output_shape->add_dim();
            dim->set_dim_value(1);
          }
        }
      }
    });
ONNX_CONTRIB_OPERATOR_SCHEMA(DropoutGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
@ -1074,11 +1145,11 @@ Example 4:
"T1",
OpSchema::Optional)
.Input(3, "training_mode",
"If set to true then it indicates dropout is being used for training. It is an optional value hence unless "
"specified explicitly, it is false. If it is false, ratio is ignored and the operation mimics inference mode where "
"nothing will be dropped from the input data and if mask is requested as output it will contain all ones.",
"T2",
OpSchema::Optional)
"If set to true then it indicates dropout is being used for training. It is an optional value hence unless "
"specified explicitly, it is false. If it is false, ratio is ignored and the operation mimics inference mode where "
"nothing will be dropped from the input data and if mask is requested as output it will contain all ones.",
"T2",
OpSchema::Optional)
.Output(0, "dx", "Gradient of the input.", "T")
.TypeConstraint(
"T",

View file

@ -214,5 +214,130 @@ TEST(ReductionOpTest, ReduceAllL2Many) {
#endif
// int32 data of shape {3, 2, 2} reduced over axes {0, 2} with keepdims = 1.
TEST(ReductionOpTest, ReduceSumTraining_int32) {
  OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
  test.AddAttribute("keepdims", static_cast<int64_t>(1));

  const std::vector<int32_t> input_values = {1, 2,
                                             3, 4,
                                             5, 6,
                                             7, 8,
                                             9, 10,
                                             11, 12};
  const std::vector<int32_t> expected_values = {33, 45};

  test.AddInput<int32_t>("data", {3, 2, 2}, input_values);
  test.AddInput<int64_t>("axes", {2}, {0, 2}, true /*is_initializer*/);
  test.AddOutput<int32_t>("reduced", {1, 2, 1}, expected_values);
  test.Run();
}
// Empty (initializer) axes without noop_with_empty_axes: every dimension is reduced,
// and keepdims = 1 leaves a {1, 1, 1} result.
TEST(ReductionOpTest, ReduceSumTraining_default_axes_keepdims) {
  OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
  test.AddAttribute("keepdims", static_cast<int64_t>(1));

  const std::vector<float> input_values = {1.0f, 2.0f,
                                           3.0f, 4.0f,
                                           5.0f, 6.0f,
                                           7.0f, 8.0f,
                                           9.0f, 10.0f,
                                           11.0f, 12.0f};
  const std::vector<float> expected_values = {78.0f};

  test.AddInput<float>("data", {3, 2, 2}, input_values);
  test.AddInput<int64_t>("axes", {0}, {}, true /*is_initializer*/);
  test.AddOutput<float>("reduced", {1, 1, 1}, expected_values);
  test.Run();
}
// Same reduction as the default-axes case, but the empty axes tensor is fed as a
// regular input instead of an initializer.
TEST(ReductionOpTest, ReduceSumTraining_axes_not_initializer) {
  OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
  test.AddAttribute("keepdims", static_cast<int64_t>(1));

  const std::vector<float> input_values = {1.0f, 2.0f,
                                           3.0f, 4.0f,
                                           5.0f, 6.0f,
                                           7.0f, 8.0f,
                                           9.0f, 10.0f,
                                           11.0f, 12.0f};
  const std::vector<float> expected_values = {78.0f};

  test.AddInput<float>("data", {3, 2, 2}, input_values);
  test.AddInput<int64_t>("axes", {0}, {});
  test.AddOutput<float>("reduced", {1, 1, 1}, expected_values);
  test.Run();
}
// Empty axes with noop_with_empty_axes = 1: the output must equal the input.
TEST(ReductionOpTest, ReduceSumTraining_empty_axes_noop) {
  OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
  test.AddAttribute("keepdims", static_cast<int64_t>(1));
  test.AddAttribute("noop_with_empty_axes", static_cast<int64_t>(1));

  // The same values serve as both the input and the expected output.
  const std::vector<float> values = {1.0f, 2.0f,
                                     3.0f, 4.0f,
                                     5.0f, 6.0f,
                                     7.0f, 8.0f,
                                     9.0f, 10.0f,
                                     11.0f, 12.0f};

  test.AddInput<float>("data", {3, 2, 2}, values);
  test.AddInput<int64_t>("axes", {0}, {}, true /*is_initializer*/);
  test.AddOutput<float>("reduced", {3, 2, 2}, values);
  test.Run();
}
// Reduces axis 1 of a {1, 2, 2} tensor with keepdims = 0, dropping the reduced dimension.
TEST(ReductionOpTest, ReduceSumTraining_do_not_keepdims) {
  OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
  test.AddAttribute("keepdims", static_cast<int64_t>(0));

  const std::vector<float> input_values = {1.0f, 2.0f,
                                           3.0f, 4.0f};
  const std::vector<float> expected_values = {4.0f, 6.0f};

  test.AddInput<float>("data", {1, 2, 2}, input_values);
  test.AddInput<int64_t>("axes", {1}, {1}, true /*is_initializer*/);
  test.AddOutput<float>("reduced", {1, 2}, expected_values);
  test.Run();
}
// Negative axis: -2 on a rank-3 input is expected to give the same result as axis 1.
TEST(ReductionOpTest, ReduceSumTraining_neg_axis) {
  OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
  test.AddAttribute("keepdims", static_cast<int64_t>(0));

  const std::vector<float> input_values = {1.0f, 2.0f,
                                           3.0f, 4.0f};
  const std::vector<float> expected_values = {4.0f, 6.0f};

  test.AddInput<float>("data", {1, 2, 2}, input_values);
  test.AddInput<int64_t>("axes", {1}, {-2}, true /*is_initializer*/);
  test.AddOutput<float>("reduced", {1, 2}, expected_values);
  test.Run();
}
#ifdef USE_CUDA
// MLFloat16 input and output: reduces axes {0, 1} of a {3, 2, 2} tensor with keepdims = 0.
TEST(ReductionOpTest, ReduceSumTrainingHalfHalf) {
  OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
  test.AddAttribute("keepdims", static_cast<int64_t>(0));

  // Input values are written as float and converted to half precision.
  std::vector<float> input_fp32 = {1.0f, 2.0f,
                                   3.0f, 4.0f,
                                   5.0f, 6.0f,
                                   7.0f, 8.0f,
                                   9.0f, 10.0f,
                                   11.0f, 12.0f};
  std::vector<MLFloat16> input_fp16(12);
  ConvertFloatToMLFloat16(input_fp32.data(), input_fp16.data(), 12);

  // Expected sums, converted the same way.
  std::vector<float> expected_fp32 = {36.0f, 42.0f};
  std::vector<MLFloat16> expected_fp16(2);
  ConvertFloatToMLFloat16(expected_fp32.data(), expected_fp16.data(), 2);

  test.AddInput<MLFloat16>("data", {3, 2, 2}, input_fp16);
  test.AddInput<int64_t>("axes", {2}, {0, 1}, true /*is_initializer*/);
  test.AddOutput<MLFloat16>("reduced", {2}, expected_fp16);
  test.Run();
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -16,6 +16,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AdamO
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ReduceSumTraining);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ReduceSumTraining);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int64_t, ReduceSumTraining);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropy);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad);
@ -75,6 +79,10 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int64_t, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropy)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxCrossEntropyGrad)>,

View file

@ -12,6 +12,10 @@ namespace cuda {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, View);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Group);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SGDOptimizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReduceSumTraining);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReduceSumTraining);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining);
// Adam
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int64_t_float_float_float_float, AdamOptimizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int64_t_float_MLFloat16_float_float, AdamOptimizer);
@ -122,6 +126,10 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, View)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Group)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SGDOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining)>,
// Adam
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int64_t_float_float_float_float, AdamOptimizer)>,