From bfbcc89db1adc8689b1b9dde18e5d2872ea40c87 Mon Sep 17 00:00:00 2001 From: ashbhandare Date: Fri, 14 May 2021 09:00:27 -0700 Subject: [PATCH] Add MLFloat16 support for SoftmaxCrossEntropyLoss for CUDA EP (#7679) * Forward op changes * Add tests, improve kernel * add opset 13 registration, remove unnecessary changes * Add fp16 grad for SCELoss, review comments --- .../cuda/reduction/reduction_functions.cu | 1 + .../providers/compare_provider_test_utils.cc | 35 +++- .../providers/compare_provider_test_utils.h | 3 +- onnxruntime/test/util/compare_ortvalue.cc | 10 +- .../training_ops/cuda/cross_entropy_test.cc | 162 ++++++++++++++++-- .../cuda/cuda_training_kernels.cc | 6 + .../loss/softmax_cross_entropy_loss_impl.cc | 70 +++++--- .../loss/softmax_cross_entropy_loss_impl.cu | 125 ++++++++------ .../loss/softmax_cross_entropy_loss_impl.h | 8 +- .../rocm/rocm_training_kernels.cc | 6 + 10 files changed, 316 insertions(+), 110 deletions(-) diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu index 6ac4e64900..c83060ef48 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu @@ -343,6 +343,7 @@ Status reduce_mean( #define INSTANTIATE_REDUCE_SUM(TIn, TOut) \ template Status reduce_sum(cudaStream_t stream, const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size) +INSTANTIATE_REDUCE_SUM(half, half); INSTANTIATE_REDUCE_SUM(half, float); INSTANTIATE_REDUCE_SUM(float, float); INSTANTIATE_REDUCE_SUM(double, double); diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc index 53082fcbbb..b9b101ed72 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.cc +++ b/onnxruntime/test/providers/compare_provider_test_utils.cc @@ -2,9 +2,10 @@ // Licensed under the MIT License. #include "core/session/inference_session.h" +#include "core/optimizer/insert_cast_transformer.h" #include "test/util/include/default_providers.h" #include "test/providers/compare_provider_test_utils.h" - +#include "test/test_environment.h" #include "test/compare_ortvalue.h" using namespace std; @@ -41,7 +42,8 @@ std::unique_ptr GetExecutionProvider(const std::string& prov void CompareOpTester::CompareWithCPU(const std::string& target_provider_type, double per_sample_tolerance, - double relative_per_sample_tolerance) { + double relative_per_sample_tolerance, + const bool need_cpu_cast) { #ifndef NDEBUG run_called_ = true; #endif @@ -52,7 +54,21 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type, auto p_model = BuildGraph(); auto& graph = p_model->MainGraph(); - Status status = graph.Resolve(); + Status status; + + // In InferenceSession::Initialize(), the call to graph partitioner, which is responsible + // for Inlining function bodies for ops whose kernel is missing happens before the + // Cast Transformer. As a result, for MLFloat16 tests where the node is missing a CPU kernel, + // the function body is instead used for CPU pass. This option allows the comparison with + // the CPU kernel by adding the input/output casts before looking for a registered CPU kernel. + if (need_cpu_cast) { + InsertCastTransformer transformer("Test"); + bool modified = false; + status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(status.IsOK()); + } + + status = graph.Resolve(); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); if (!status.IsOK()) { return; @@ -102,11 +118,22 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type, } // run with target provider + // build the graph again as the cpu graph may be with casts + auto p_tp_model = BuildGraph(); + auto& tp_graph = p_tp_model->MainGraph(); + + status = tp_graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + if (!status.IsOK()) { + return; + } InferenceSession target_session_object{so, GetEnvironment()}; EXPECT_TRUE(target_session_object.RegisterExecutionProvider(std::move(target_execution_provider)).IsOK()); - std::istringstream model_proto_str1(s1); + std::string s2; + p_tp_model->ToProto().SerializeToString(&s2); + std::istringstream model_proto_str1(s2); status = target_session_object.Load(model_proto_str1); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); if (!status.IsOK()) { diff --git a/onnxruntime/test/providers/compare_provider_test_utils.h b/onnxruntime/test/providers/compare_provider_test_utils.h index dc9e1b43f7..88eb891df0 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.h +++ b/onnxruntime/test/providers/compare_provider_test_utils.h @@ -19,7 +19,8 @@ class CompareOpTester : public OpTester { void CompareWithCPU(const std::string& target_provider_type, double per_sample_tolerance = 1e-4, - double relative_per_sample_tolerance = 1e-4); + double relative_per_sample_tolerance = 1e-4, + const bool need_cpu_cast = false); }; } // namespace test diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index b69f5159bc..b2c3560d5c 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -129,6 +129,8 @@ std::pair CompareFloat16Result(const Tensor& outval const size_t size1 = static_cast(expected_value.Shape().Size()); const MLFloat16* expected_output = expected_value.template Data(); const MLFloat16* real_output = outvalue.template Data(); + std::ostringstream oss; + COMPARE_RESULT result = COMPARE_RESULT::SUCCESS; for (size_t di = 0; di != size1; ++di) { float expected = Eigen::half_impl::half_to_float(Eigen::half_impl::__half_raw(expected_output[di].val)); float real = Eigen::half_impl::half_to_float(Eigen::half_impl::__half_raw(real_output[di].val)); @@ -136,13 +138,11 @@ std::pair CompareFloat16Result(const Tensor& outval const double diff = std::fabs(expected - real); const double rtol = per_sample_tolerance + relative_per_sample_tolerance * std::fabs(expected); if (!IsResultCloselyMatch(real, expected, diff, rtol)) { - std::ostringstream oss; - oss << "expected " << expected << ", got " << real << ", diff: " << diff << ", tol=" << rtol; - - return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + oss << "idx: " << di << "expected " << expected << ", got " << real << ", diff: " << diff << ", tol=" << rtol << "\n"; + result = COMPARE_RESULT::RESULT_DIFFERS; } } - return std::make_pair(COMPARE_RESULT::SUCCESS, ""); + return std::make_pair(result, oss.str()); } std::pair CompareBFloat16Result(const Tensor& outvalue, const Tensor& expected_value, diff --git a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc index f748969ac8..b4cffa119e 100644 --- a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc @@ -285,7 +285,9 @@ static void TestSoftmaxCrossEntropyLoss(const std::vector* X_dims, const std::vector* Y_dims, const std::vector* log_prob_dims, const std::string& reduction, - const std::int64_t ignore_index = -1) { + const std::int64_t ignore_index = -1, + const bool test_fp16 = false, + const double error_tolerance = 1e-4) { CompareOpTester test("SoftmaxCrossEntropyLoss", 12, onnxruntime::kOnnxDomain); test.AddAttribute("reduction", reduction); test.AddAttribute("ignore_index", ignore_index); @@ -298,22 +300,47 @@ static void TestSoftmaxCrossEntropyLoss(const std::vector* X_dims, if (index_data.size() > 0) { index_data[0] = ignore_index; } + if (test_fp16) { + std::vector X_data_half(X_data.size()); + ConvertFloatToMLFloat16(X_data.data(), X_data_half.data(), int(X_data.size())); + test.AddInput("X", *X_dims, X_data_half); + } else { + test.AddInput("X", *X_dims, X_data); + } - test.AddInput("X", *X_dims, X_data); test.AddInput("index", *index_dims, index_data); if (weight_dims) { std::vector weight_data = random.Uniform(*weight_dims, 0.0f, 1.0f); - test.AddInput("weight", *weight_dims, weight_data); + if (test_fp16) { + std::vector weight_data_half(weight_data.size()); + ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), int(weight_data.size())); + test.AddInput("weight", *weight_dims, weight_data_half); + } else { + test.AddInput("weight", *weight_dims, weight_data); + } } - std::vector Y_data = FillZeros(*Y_dims); - std::vector log_prob_data = FillZeros(*log_prob_dims); + if (test_fp16) { + std::vector Y_data = FillZeros(*Y_dims); + test.AddOutput("output", *Y_dims, Y_data); - test.AddOutput("output", *Y_dims, Y_data); - test.AddOutput("log_prob", *log_prob_dims, log_prob_data); + if (log_prob_dims) { + std::vector log_prob_data = FillZeros(*log_prob_dims); + test.AddOutput("log_prob", *log_prob_dims, log_prob_data); + } - test.CompareWithCPU(kGpuExecutionProvider); + test.CompareWithCPU(kGpuExecutionProvider, error_tolerance, error_tolerance, true); + } else { + std::vector Y_data = FillZeros(*Y_dims); + test.AddOutput("output", *Y_dims, Y_data); + + if (log_prob_dims) { + std::vector log_prob_data = FillZeros(*log_prob_dims); + test.AddOutput("log_prob", *log_prob_dims, log_prob_data); + } + test.CompareWithCPU(kGpuExecutionProvider); + } } TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_TinySizeTensor) { @@ -339,6 +366,29 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_TinySizeTensor) { TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", 0); } +TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_TinySizeTensor_half) { + std::vector X_dims{8, 2}; + std::vector index_dims{8}; + std::vector weight_dims{2}; + std::vector Y_dims{}; + std::vector Y_dims_none{8}; + std::vector log_prob_dims{8, 2}; + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2); + + // Just test ignore_index for small tensor because it will increase test time a lot with little verification gain. + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", 0, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", 0, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", 0, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", 0, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", 0, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", 0, true, 5e-2); +} + TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_SmallSizeTensor) { std::vector X_dims{8, 20, 10}; std::vector index_dims{8, 10}; @@ -354,6 +404,21 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_SmallSizeTensor) { TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none"); } +TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_SmallSizeTensor_half) { + std::vector X_dims{8, 20, 10}; + std::vector index_dims{8, 10}; + std::vector weight_dims{20}; + std::vector Y_dims{}; + std::vector Y_dims_none{8, 10}; + std::vector log_prob_dims{8, 20, 10}; + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2); +} + TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_MediumSizeTensor) { std::vector X_dims{8, 1024}; std::vector index_dims{8}; @@ -369,6 +434,21 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_MediumSizeTensor) { TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none"); } +TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_MediumSizeTensor_half) { + std::vector X_dims{8, 1024}; + std::vector index_dims{8}; + std::vector weight_dims{1024}; + std::vector Y_dims{}; + std::vector Y_dims_none{8}; + std::vector log_prob_dims{8, 1024}; + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2); + TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2); +} + // TODO fix flaky test // failing random seed: 2873512643 TEST(CudaKernelTest, DISABLED_SoftmaxCrossEntropyLoss_LargeSizeTensor) { @@ -391,7 +471,9 @@ static void TestSoftmaxCrossEntropyLossGrad(const std::vector& dY_dims, const std::vector& index_dims, const std::vector& dX_dims, const std::string& reduction, - const std::int64_t ignore_index = -1) { + const std::int64_t ignore_index = -1, + const bool test_fp16 = false, + const double error_tolerance = 1e-4) { CompareOpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain); test.AddAttribute("reduction", reduction); test.AddAttribute("ignore_index", ignore_index); @@ -405,16 +487,31 @@ static void TestSoftmaxCrossEntropyLossGrad(const std::vector& dY_dims, if (index_data.size() > 0) { index_data[0] = ignore_index; } + if (test_fp16) { + std::vector dY_data_half(dY_data.size()); + ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), int(dY_data.size())); + test.AddInput("dY", dY_dims, dY_data_half); - test.AddInput("dY", dY_dims, dY_data); - test.AddInput("log_prob", log_prob_dims, log_prob_data); - test.AddInput("index", index_dims, index_data); + std::vector log_prob_data_half(log_prob_data.size()); + ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), int(log_prob_data.size())); + test.AddInput("log_prob", log_prob_dims, log_prob_data_half); - std::vector dX_data = FillZeros(dX_dims); + test.AddInput("index", index_dims, index_data); - test.AddOutput("dX", dX_dims, dX_data); + std::vector dX_data = FillZeros(dX_dims); - test.CompareWithCPU(kGpuExecutionProvider); + test.AddOutput("dX", dX_dims, dX_data); + test.CompareWithCPU(kGpuExecutionProvider, error_tolerance, error_tolerance); + } else { + test.AddInput("dY", dY_dims, dY_data); + test.AddInput("log_prob", log_prob_dims, log_prob_data); + test.AddInput("index", index_dims, index_data); + + std::vector dX_data = FillZeros(dX_dims); + + test.AddOutput("dX", dX_dims, dX_data); + test.CompareWithCPU(kGpuExecutionProvider); + } } TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_TinySizeTensor) { @@ -452,5 +549,40 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_LargeSizeTensor) { TestSoftmaxCrossEntropyLossGrad({2, 30528}, log_prob_dims, index_dims, dX_dims, "none"); } +TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_TinySizeTensor_half) { + std::vector dY_dims{}; + std::vector log_prob_dims{8, 2}; + std::vector index_dims{8}; + std::vector dX_dims{8, 2}; + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad({8}, log_prob_dims, index_dims, dX_dims, "none", -1, true, 5e-2); + + // Just test ignore_index for small tensor because it will increase test time a lot with little verification gain. + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", 0, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", 0, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad({8}, log_prob_dims, index_dims, dX_dims, "none", 0, true, 5e-2); +} + +TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_SmallSizeTensor_half) { + std::vector dY_dims{}; + std::vector log_prob_dims{8, 20, 10}; + std::vector index_dims{8, 10}; + std::vector dX_dims{8, 20, 10}; + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad({8, 10}, log_prob_dims, index_dims, dX_dims, "none", -1, true, 5e-2); +} + +TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_LargeSizeTensor_half) { + std::vector dY_dims{}; + std::vector log_prob_dims{2, 512, 30528}; + std::vector index_dims{2, 30528}; + std::vector dX_dims{2, 512, 30528}; + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", -1, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLossGrad({2, 30528}, log_prob_dims, index_dims, dX_dims, "none", -1, true, 5e-2); +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index d1eeb6832a..f5e7e62c09 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -53,9 +53,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropy); // class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropyGrad); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, int64_t, SoftmaxCrossEntropyLoss); class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, int64_t, SoftmaxCrossEntropyLoss); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, int64_t, SoftmaxCrossEntropyLoss); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, SoftmaxGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SoftmaxGrad); @@ -254,9 +257,12 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc index 6b8d1084ca..a860497fdd 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc +++ b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc @@ -4,6 +4,7 @@ #include "core/providers/cuda/math/softmax.h" #include "core/providers/cuda/reduction/reduction_functions.h" #include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/providers/cpu/controlflow/scan_utils.h" #include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h" #include "orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h" @@ -37,6 +38,7 @@ namespace cuda { template Status SoftmaxCrossEntropyLoss::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToCudaType::MappedType CudaT; const Tensor& logit = *ctx->Input(0); const Tensor& label = *ctx->Input(1); const TensorShape logit_shape{logit.Shape()}; @@ -108,38 +110,44 @@ Status SoftmaxCrossEntropyLoss::ComputeInternal(OpKernelContext* ctx) co IAllocatorUniquePtr weight_data_nd = GetScratchBuffer(N_D); T* weight_data_nd_data = weight_data_nd.get(); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(weight_data_nd_data, 0, N_D * sizeof(T), Stream())); - ComputeWeightsSoftmaxCrossEntropyImpl(Stream(), label_data, weight_data, N_D, C, ignore_index_, weight_data_nd_data); + ComputeWeightsSoftmaxCrossEntropyImpl(Stream(), + label_data, + reinterpret_cast(weight_data), + N_D, C, + ignore_index_, + reinterpret_cast(weight_data_nd_data)); // Compute buffer size in byte for reduction APIs. const auto buffer_size = - compute_reduction_buffer_size(static_cast(N_D)); + compute_reduction_buffer_size(static_cast(N_D)); // Allocate reduction buffer whose size is buffer_size bytes, or nullptr if no reduction. IAllocatorUniquePtr reduction_buffer = GetScratchBuffer( reduction_ != ReductionType::NONE ? buffer_size : 0); - auto normalize_factor_data = GetScratchBuffer(1); + typedef AccumulationType_t TBuf; + auto normalize_factor_data = GetScratchBuffer(1); if (reduction_ == ReductionType::MEAN) { ORT_RETURN_IF_ERROR(reduce_sum( Stream(), - weight_data_nd_data, + reinterpret_cast(weight_data_nd_data), normalize_factor_data.get(), static_cast(N_D), reduction_buffer.get(), buffer_size)); } else { - const T normalize_factor = static_cast(1); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice, Stream())); + const TBuf normalize_factor = static_cast(1.0f); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(TBuf), cudaMemcpyHostToDevice, Stream())); } SoftmaxCrossEntropyLossImpl(Stream(), - log_prob_data, + reinterpret_cast(log_prob_data), label_data, - weight_data_nd_data, + reinterpret_cast(weight_data_nd_data), normalize_factor_data.get(), N_D, C, ignore_index_, - tmp_loss_sample_buffer); + reinterpret_cast(tmp_loss_sample_buffer)); // Transpose log probability from [N, D1, D2...Dk, C] to [N, C, D1, D2 .. Dk]. if (logit_shape.NumDimensions() > 2 && log_prob != nullptr) { @@ -159,8 +167,8 @@ Status SoftmaxCrossEntropyLoss::ComputeInternal(OpKernelContext* ctx) co // ReduceSum on loss_per_sample ORT_RETURN_IF_ERROR(reduce_sum( Stream(), - tmp_loss_sample_buffer, - total_loss_data, + reinterpret_cast(tmp_loss_sample_buffer), + reinterpret_cast(total_loss_data), static_cast(N_D), reduction_buffer.get(), buffer_size)); @@ -171,6 +179,7 @@ Status SoftmaxCrossEntropyLoss::ComputeInternal(OpKernelContext* ctx) co template Status SoftmaxCrossEntropyLossGrad::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToCudaType::MappedType CudaT; const Tensor& dY = *ctx->Input(0); const Tensor& log_prob = *ctx->Input(1); const Tensor& label = *ctx->Input(2); @@ -212,37 +221,43 @@ Status SoftmaxCrossEntropyLossGrad::ComputeInternal(OpKernelContext* ctx IAllocatorUniquePtr weight_data_nd = GetScratchBuffer(N_D); T* weight_data_nd_data = weight_data_nd.get(); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(weight_data_nd_data, 0, N_D * sizeof(T), Stream())); - ComputeWeightsSoftmaxCrossEntropyImpl(Stream(), label_data, weight_data, N_D, C, ignore_index_, weight_data_nd_data); - auto normalize_factor_data = GetScratchBuffer(1); + ComputeWeightsSoftmaxCrossEntropyImpl(Stream(), + label_data, + reinterpret_cast(weight_data), + N_D, C, + ignore_index_, + reinterpret_cast(weight_data_nd_data)); + typedef AccumulationType_t TBuf; + auto normalize_factor_data = GetScratchBuffer(1); if (reduction_ == ReductionType::MEAN) { // Compute buffer size in byte for reduction APIs. const auto buffer_size = - compute_reduction_buffer_size(static_cast(N_D)); + compute_reduction_buffer_size(static_cast(N_D)); // Allocate reduction buffer whose size is buffer_size bytes. IAllocatorUniquePtr reduction_buffer = GetScratchBuffer( buffer_size); ORT_RETURN_IF_ERROR(reduce_sum( Stream(), - weight_data_nd_data, + reinterpret_cast(weight_data_nd_data), normalize_factor_data.get(), static_cast(N_D), reduction_buffer.get(), buffer_size)); } else { - const T normalize_factor = static_cast(1); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice, Stream())); + const TBuf normalize_factor = static_cast(1.0f); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(TBuf), cudaMemcpyHostToDevice, Stream())); } SoftmaxCrossEntropyLossGradImpl(Stream(), - dY_data, - log_prob_data, + reinterpret_cast(dY_data), + reinterpret_cast(log_prob_data), label_data, - weight_data_nd_data, + reinterpret_cast(weight_data_nd_data), normalize_factor_data.get(), N_D, C, ReductionType::NONE == reduction_, - d_logit_data); + reinterpret_cast(d_logit_data)); // Transpose logit from [N, D1, D2...Dk, C] to [N, C, D1, D2 .. Dk] if (probability_shape.NumDimensions() > 2) { @@ -261,16 +276,19 @@ Status SoftmaxCrossEntropyLossGrad::ComputeInternal(OpKernelContext* ctx return Status::OK(); } -#define SPECIALIZED_VERSIONED_COMPUTE_SPARSE(Class, T, Tin, domain, startver, endvar) \ +#define INSTANTIATE_VERSIONED_COMPUTE_SPARSE(Class, T, Tin, domain, startver, endvar) \ REGISTER_KERNEL_VERSIONED_TYPED_TWO_TYPES(Class, T, Tin, domain, startver, endvar) -#define SPECIALIZED_COMPUTE_SPARSE(Class, T, Tin, domain, version) \ +#define INSTANTIATE_COMPUTE_SPARSE(Class, T, Tin, domain, version) \ REGISTER_KERNEL_TYPED_TWO_TYPES(Class, T, Tin, domain, version) \ template Status Class::ComputeInternal(OpKernelContext* ctx) const; -SPECIALIZED_VERSIONED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 12, 12) -SPECIALIZED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 13) -SPECIALIZED_COMPUTE_SPARSE(SoftmaxCrossEntropyLossGrad, float, int64_t, kMSDomain, 1) +INSTANTIATE_VERSIONED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 12, 12) +INSTANTIATE_VERSIONED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, MLFloat16, int64_t, kOnnxDomain, 12, 12) +INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 13) +INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, MLFloat16, int64_t, kOnnxDomain, 13) +INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLossGrad, float, int64_t, kMSDomain, 1) +INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLossGrad, MLFloat16, int64_t, kMSDomain, 1) } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu index 02bd7ebf25..82565b8ab9 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu @@ -17,9 +17,10 @@ __global__ void _ComputeWeightsSoftmaxCrossEntropy( CUDA_LONG C, CUDA_LONG ignore_index) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N_D); + const T ONE_T = 1; if (label_data[i] != ignore_index) { CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < C); - weight_data_nd[i] = weight_data != nullptr ? weight_data[label_data[i]] : 1; + weight_data_nd[i] = weight_data != nullptr ? weight_data[label_data[i]] : ONE_T; } } @@ -45,12 +46,12 @@ void ComputeWeightsSoftmaxCrossEntropyImpl( II); } -template +template __global__ void _WeightedSoftmaxCrossEntropyLoss( const T* log_prob_data, const Tin* label_data, const T* weight_data, - const T* normalize_factor_data, + const TAcc* normalize_factor_data, T* output_data, CUDA_LONG N_D, CUDA_LONG C, @@ -60,17 +61,18 @@ __global__ void _WeightedSoftmaxCrossEntropyLoss( output_data[i] = 0; } else { CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < C); - output_data[i] = -log_prob_data[i * C + label_data[i]] * weight_data[i] / (*normalize_factor_data); + output_data[i] = static_cast(static_cast(-log_prob_data[i * C + label_data[i]] * weight_data[i]) / + *normalize_factor_data); } } -template +template void SoftmaxCrossEntropyLossImpl( cudaStream_t stream, const T* log_prob, const Tin* label, const T* weight, - const T* normalize_factor, + const TAcc* normalize_factor, size_t count, size_t label_depth, int64_t ignore_index, @@ -79,7 +81,7 @@ void SoftmaxCrossEntropyLossImpl( CUDA_LONG N_D = static_cast(count); CUDA_LONG C = static_cast(label_depth); CUDA_LONG II = static_cast(ignore_index); - _WeightedSoftmaxCrossEntropyLoss<<>>( + _WeightedSoftmaxCrossEntropyLoss<<>>( log_prob, label, weight, @@ -90,28 +92,29 @@ void SoftmaxCrossEntropyLossImpl( II); } -#define SPECIALIZED_IMPL_SoftMaxEntropyLossImpl(T, Tin) \ - template void SoftmaxCrossEntropyLossImpl( \ - cudaStream_t stream, \ - const T* log_prob, \ - const Tin* label, \ - const T* weight, \ - const T* normalize_factor, \ - size_t count, \ - size_t label_depth, \ - int64_t ignore_index, \ +#define INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(T, TAcc, Tin) \ + template void SoftmaxCrossEntropyLossImpl( \ + cudaStream_t stream, \ + const T* log_prob, \ + const Tin* label, \ + const T* weight, \ + const TAcc* normalize_factor, \ + size_t count, \ + size_t label_depth, \ + int64_t ignore_index, \ T* output_data); -SPECIALIZED_IMPL_SoftMaxEntropyLossImpl(float, int32_t) -SPECIALIZED_IMPL_SoftMaxEntropyLossImpl(float, int64_t) +INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(float, float, int32_t) +INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(float, float, int64_t) +INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(half, float, int64_t) -template +template __global__ void _WeightedSoftmaxCrossEntropyLossGrad( const T* dY, const T* log_prob, const Tin* label, const T* weight, - const T* normalize_factor, + const TAcc* normalize_factor, T* output_data, CUDA_LONG N_D, CUDA_LONG C) { @@ -119,24 +122,29 @@ __global__ void _WeightedSoftmaxCrossEntropyLossGrad( int row = i / C; int d = i % C; - CUDA_KERNEL_ASSERT(weight[row] == 0 || (label[row] >= 0 && label[row] < C)); - if(0 == *normalize_factor){ - // normalize_factor is sum of labels' weights. Because zero - // sum implies all weights are 0, the loss function should + const T ZERO_T = 0; + const TAcc ZERO_TAcc = 0; + const TAcc ONE_TAcc = 1; + CUDA_KERNEL_ASSERT(weight[row] == ZERO_T || (label[row] >= 0 && label[row] < C)); + if (ZERO_TAcc == *normalize_factor) { + // normalize_factor is sum of labels' weights. Because zero + // sum implies all weights are 0, the loss function should // be constant 0 and its corresponding gradient should be 0 as well. - output_data[i] = 0; + output_data[i] = ZERO_T; } else { - output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor); + output_data[i] = static_cast(static_cast((*dY) * weight[row]) * + (_Exp(static_cast(log_prob[i])) - ONE_TAcc * (TAcc)(d == label[row])) / + (*normalize_factor)); } } -template +template __global__ void _WeightedReductionNoneSoftmaxCrossEntropyLossGrad( const T* dY, const T* log_prob, const Tin* label, const T* weight, - const T* normalize_factor, + const TAcc* normalize_factor, T* output_data, CUDA_LONG N_D, CUDA_LONG C) { @@ -144,25 +152,30 @@ __global__ void _WeightedReductionNoneSoftmaxCrossEntropyLossGrad( int row = i / C; int d = i % C; - CUDA_KERNEL_ASSERT(weight[row] == 0 || (label[row] >= 0 && label[row] < C)); - if(0 == *normalize_factor){ - // normalize_factor is sum of labels' weights. Because zero - // sum implies all weights are 0, the loss function should + const T ZERO_T = 0; + const TAcc ZERO_TAcc = 0; + const TAcc ONE_TAcc = 1; + CUDA_KERNEL_ASSERT(weight[row] == ZERO_T || (label[row] >= 0 && label[row] < C)); + if (ZERO_TAcc == *normalize_factor) { + // normalize_factor is sum of labels' weights. Because zero + // sum implies all weights are 0, the loss function should // be constant 0 and its corresponding gradient should be 0 as well. - output_data[i] = 0; + output_data[i] = ZERO_T; } else { - output_data[i] = dY[row] * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor); + output_data[i] = static_cast(static_cast(dY[row] * weight[row]) * + (_Exp(static_cast(log_prob[i])) - ONE_TAcc * (TAcc)(d == label[row])) / + (*normalize_factor)); } } -template +template void SoftmaxCrossEntropyLossGradImpl( cudaStream_t stream, const T* dY, const T* log_prob, const Tin* label, const T* weight, - const T* normalize_factor, + const TAcc* normalize_factor, size_t count, size_t label_depth, bool reduction_none, @@ -171,7 +184,7 @@ void SoftmaxCrossEntropyLossGradImpl( CUDA_LONG C = static_cast(label_depth); int blocksPerGrid = (int)(ceil(static_cast(N_D * C) / GridDim::maxThreadsPerBlock)); if (reduction_none) { - _WeightedReductionNoneSoftmaxCrossEntropyLossGrad<<>>( + _WeightedReductionNoneSoftmaxCrossEntropyLossGrad<<>>( dY, log_prob, label, @@ -181,7 +194,7 @@ void SoftmaxCrossEntropyLossGradImpl( N_D, C); } else { - _WeightedSoftmaxCrossEntropyLossGrad<<>>( + _WeightedSoftmaxCrossEntropyLossGrad<<>>( dY, log_prob, label, @@ -193,23 +206,24 @@ void SoftmaxCrossEntropyLossGradImpl( } } -#define SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(T, Tin) \ - template void SoftmaxCrossEntropyLossGradImpl( \ - cudaStream_t stream, \ - const T* dY, \ - const T* log_prob, \ - const Tin* label, \ - const T* weight, \ - const T* normalize_factor, \ - size_t count, \ - size_t label_depth, \ - bool reducation_none, \ +#define INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(T, TAcc, Tin) \ + template void SoftmaxCrossEntropyLossGradImpl( \ + cudaStream_t stream, \ + const T* dY, \ + const T* log_prob, \ + const Tin* label, \ + const T* weight, \ + const TAcc* normalize_factor, \ + size_t count, \ + size_t label_depth, \ + bool reducation_none, \ T* output_data); -SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(float, int32_t) -SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(float, int64_t) +INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(float, float, int32_t) +INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(float, float, int64_t) +INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(half, float, int64_t) -#define SPECIALIZED_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(T, Tin) \ +#define INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(T, Tin) \ template void ComputeWeightsSoftmaxCrossEntropyImpl( \ cudaStream_t stream, \ const Tin* label, \ @@ -219,8 +233,9 @@ SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(float, int64_t) int64_t ignore_index, \ T* weight_data_nd); -SPECIALIZED_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int32_t) -SPECIALIZED_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int64_t) +INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int32_t) +INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int64_t) +INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(half, int64_t) } // namespace cuda } // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h index d368fe9fbd..26604b00cc 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h +++ b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h @@ -10,26 +10,26 @@ namespace onnxruntime { namespace cuda { -template +template void SoftmaxCrossEntropyLossImpl( cudaStream_t stream, const T* log_prob, const Tin* label, const T* weight, - const T* normalize_factor, + const TAcc* normalize_factor, size_t count, size_t label_depth, int64_t ignore_index, T* output_data); -template +template void SoftmaxCrossEntropyLossGradImpl( cudaStream_t stream, const T* dY, const T* log_prob, const Tin* label, const T* weight, - const T* normalize_factor, + const TAcc* normalize_factor, size_t count, size_t label_depth, bool reduction_none, diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 8388fa79d5..148cac7fa8 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -52,8 +52,11 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom // class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropyGrad); class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, int64_t, SoftmaxCrossEntropyLoss); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, int64_t, SoftmaxCrossEntropyLoss); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, int64_t, SoftmaxCrossEntropyLoss); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SoftmaxGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SoftmaxGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SoftmaxGrad); @@ -193,8 +196,11 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo,