mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add MLFloat16 support for SoftmaxCrossEntropyLoss for CUDA EP (#7679)
* Forward op changes * Add tests, improve kernel * add opset 13 registration, remove unnecessary changes * Add fp16 grad for SCELoss, review comments
This commit is contained in:
parent
39fac6d304
commit
bfbcc89db1
10 changed files with 316 additions and 110 deletions
|
|
@ -343,6 +343,7 @@ Status reduce_mean(
|
|||
|
||||
#define INSTANTIATE_REDUCE_SUM(TIn, TOut) \
|
||||
template Status reduce_sum<TIn, TOut>(cudaStream_t stream, const TIn* input, TOut* output, int size, void* buffer, size_t buffer_size)
|
||||
INSTANTIATE_REDUCE_SUM(half, half);
|
||||
INSTANTIATE_REDUCE_SUM(half, float);
|
||||
INSTANTIATE_REDUCE_SUM(float, float);
|
||||
INSTANTIATE_REDUCE_SUM(double, double);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/optimizer/insert_cast_transformer.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "test/providers/compare_provider_test_utils.h"
|
||||
|
||||
#include "test/test_environment.h"
|
||||
#include "test/compare_ortvalue.h"
|
||||
|
||||
using namespace std;
|
||||
|
|
@ -41,7 +42,8 @@ std::unique_ptr<IExecutionProvider> GetExecutionProvider(const std::string& prov
|
|||
|
||||
void CompareOpTester::CompareWithCPU(const std::string& target_provider_type,
|
||||
double per_sample_tolerance,
|
||||
double relative_per_sample_tolerance) {
|
||||
double relative_per_sample_tolerance,
|
||||
const bool need_cpu_cast) {
|
||||
#ifndef NDEBUG
|
||||
run_called_ = true;
|
||||
#endif
|
||||
|
|
@ -52,7 +54,21 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type,
|
|||
auto p_model = BuildGraph();
|
||||
auto& graph = p_model->MainGraph();
|
||||
|
||||
Status status = graph.Resolve();
|
||||
Status status;
|
||||
|
||||
// In InferenceSession::Initialize(), the call to graph partitioner, which is responsible
|
||||
// for Inlining function bodies for ops whose kernel is missing happens before the
|
||||
// Cast Transformer. As a result, for MLFloat16 tests where the node is missing a CPU kernel,
|
||||
// the function body is instead used for CPU pass. This option allows the comparison with
|
||||
// the CPU kernel by adding the input/output casts before looking for a registered CPU kernel.
|
||||
if (need_cpu_cast) {
|
||||
InsertCastTransformer transformer("Test");
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
}
|
||||
|
||||
status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
if (!status.IsOK()) {
|
||||
return;
|
||||
|
|
@ -102,11 +118,22 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type,
|
|||
}
|
||||
|
||||
// run with target provider
|
||||
// build the graph again as the cpu graph may be with casts
|
||||
auto p_tp_model = BuildGraph();
|
||||
auto& tp_graph = p_tp_model->MainGraph();
|
||||
|
||||
status = tp_graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
if (!status.IsOK()) {
|
||||
return;
|
||||
}
|
||||
|
||||
InferenceSession target_session_object{so, GetEnvironment()};
|
||||
EXPECT_TRUE(target_session_object.RegisterExecutionProvider(std::move(target_execution_provider)).IsOK());
|
||||
|
||||
std::istringstream model_proto_str1(s1);
|
||||
std::string s2;
|
||||
p_tp_model->ToProto().SerializeToString(&s2);
|
||||
std::istringstream model_proto_str1(s2);
|
||||
status = target_session_object.Load(model_proto_str1);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
if (!status.IsOK()) {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ class CompareOpTester : public OpTester {
|
|||
|
||||
void CompareWithCPU(const std::string& target_provider_type,
|
||||
double per_sample_tolerance = 1e-4,
|
||||
double relative_per_sample_tolerance = 1e-4);
|
||||
double relative_per_sample_tolerance = 1e-4,
|
||||
const bool need_cpu_cast = false);
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -129,6 +129,8 @@ std::pair<COMPARE_RESULT, std::string> CompareFloat16Result(const Tensor& outval
|
|||
const size_t size1 = static_cast<size_t>(expected_value.Shape().Size());
|
||||
const MLFloat16* expected_output = expected_value.template Data<MLFloat16>();
|
||||
const MLFloat16* real_output = outvalue.template Data<MLFloat16>();
|
||||
std::ostringstream oss;
|
||||
COMPARE_RESULT result = COMPARE_RESULT::SUCCESS;
|
||||
for (size_t di = 0; di != size1; ++di) {
|
||||
float expected = Eigen::half_impl::half_to_float(Eigen::half_impl::__half_raw(expected_output[di].val));
|
||||
float real = Eigen::half_impl::half_to_float(Eigen::half_impl::__half_raw(real_output[di].val));
|
||||
|
|
@ -136,13 +138,11 @@ std::pair<COMPARE_RESULT, std::string> CompareFloat16Result(const Tensor& outval
|
|||
const double diff = std::fabs(expected - real);
|
||||
const double rtol = per_sample_tolerance + relative_per_sample_tolerance * std::fabs(expected);
|
||||
if (!IsResultCloselyMatch<float>(real, expected, diff, rtol)) {
|
||||
std::ostringstream oss;
|
||||
oss << "expected " << expected << ", got " << real << ", diff: " << diff << ", tol=" << rtol;
|
||||
|
||||
return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str());
|
||||
oss << "idx: " << di << "expected " << expected << ", got " << real << ", diff: " << diff << ", tol=" << rtol << "\n";
|
||||
result = COMPARE_RESULT::RESULT_DIFFERS;
|
||||
}
|
||||
}
|
||||
return std::make_pair(COMPARE_RESULT::SUCCESS, "");
|
||||
return std::make_pair(result, oss.str());
|
||||
}
|
||||
|
||||
std::pair<COMPARE_RESULT, std::string> CompareBFloat16Result(const Tensor& outvalue, const Tensor& expected_value,
|
||||
|
|
|
|||
|
|
@ -285,7 +285,9 @@ static void TestSoftmaxCrossEntropyLoss(const std::vector<int64_t>* X_dims,
|
|||
const std::vector<int64_t>* Y_dims,
|
||||
const std::vector<int64_t>* log_prob_dims,
|
||||
const std::string& reduction,
|
||||
const std::int64_t ignore_index = -1) {
|
||||
const std::int64_t ignore_index = -1,
|
||||
const bool test_fp16 = false,
|
||||
const double error_tolerance = 1e-4) {
|
||||
CompareOpTester test("SoftmaxCrossEntropyLoss", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute("reduction", reduction);
|
||||
test.AddAttribute("ignore_index", ignore_index);
|
||||
|
|
@ -298,22 +300,47 @@ static void TestSoftmaxCrossEntropyLoss(const std::vector<int64_t>* X_dims,
|
|||
if (index_data.size() > 0) {
|
||||
index_data[0] = ignore_index;
|
||||
}
|
||||
if (test_fp16) {
|
||||
std::vector<MLFloat16> X_data_half(X_data.size());
|
||||
ConvertFloatToMLFloat16(X_data.data(), X_data_half.data(), int(X_data.size()));
|
||||
test.AddInput<MLFloat16>("X", *X_dims, X_data_half);
|
||||
} else {
|
||||
test.AddInput<float>("X", *X_dims, X_data);
|
||||
}
|
||||
|
||||
test.AddInput<float>("X", *X_dims, X_data);
|
||||
test.AddInput<int64_t>("index", *index_dims, index_data);
|
||||
|
||||
if (weight_dims) {
|
||||
std::vector<float> weight_data = random.Uniform<float>(*weight_dims, 0.0f, 1.0f);
|
||||
test.AddInput<float>("weight", *weight_dims, weight_data);
|
||||
if (test_fp16) {
|
||||
std::vector<MLFloat16> weight_data_half(weight_data.size());
|
||||
ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), int(weight_data.size()));
|
||||
test.AddInput<MLFloat16>("weight", *weight_dims, weight_data_half);
|
||||
} else {
|
||||
test.AddInput<float>("weight", *weight_dims, weight_data);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> Y_data = FillZeros<float>(*Y_dims);
|
||||
std::vector<float> log_prob_data = FillZeros<float>(*log_prob_dims);
|
||||
if (test_fp16) {
|
||||
std::vector<MLFloat16> Y_data = FillZeros<MLFloat16>(*Y_dims);
|
||||
test.AddOutput<MLFloat16>("output", *Y_dims, Y_data);
|
||||
|
||||
test.AddOutput<float>("output", *Y_dims, Y_data);
|
||||
test.AddOutput<float>("log_prob", *log_prob_dims, log_prob_data);
|
||||
if (log_prob_dims) {
|
||||
std::vector<MLFloat16> log_prob_data = FillZeros<MLFloat16>(*log_prob_dims);
|
||||
test.AddOutput<MLFloat16>("log_prob", *log_prob_dims, log_prob_data);
|
||||
}
|
||||
|
||||
test.CompareWithCPU(kGpuExecutionProvider);
|
||||
test.CompareWithCPU(kGpuExecutionProvider, error_tolerance, error_tolerance, true);
|
||||
} else {
|
||||
std::vector<float> Y_data = FillZeros<float>(*Y_dims);
|
||||
test.AddOutput<float>("output", *Y_dims, Y_data);
|
||||
|
||||
if (log_prob_dims) {
|
||||
std::vector<float> log_prob_data = FillZeros<float>(*log_prob_dims);
|
||||
test.AddOutput<float>("log_prob", *log_prob_dims, log_prob_data);
|
||||
}
|
||||
test.CompareWithCPU(kGpuExecutionProvider);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_TinySizeTensor) {
|
||||
|
|
@ -339,6 +366,29 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_TinySizeTensor) {
|
|||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", 0);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_TinySizeTensor_half) {
|
||||
std::vector<int64_t> X_dims{8, 2};
|
||||
std::vector<int64_t> index_dims{8};
|
||||
std::vector<int64_t> weight_dims{2};
|
||||
std::vector<int64_t> Y_dims{};
|
||||
std::vector<int64_t> Y_dims_none{8};
|
||||
std::vector<int64_t> log_prob_dims{8, 2};
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2);
|
||||
|
||||
// Just test ignore_index for small tensor because it will increase test time a lot with little verification gain.
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", 0, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", 0, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", 0, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", 0, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", 0, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", 0, true, 5e-2);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_SmallSizeTensor) {
|
||||
std::vector<int64_t> X_dims{8, 20, 10};
|
||||
std::vector<int64_t> index_dims{8, 10};
|
||||
|
|
@ -354,6 +404,21 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_SmallSizeTensor) {
|
|||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none");
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_SmallSizeTensor_half) {
|
||||
std::vector<int64_t> X_dims{8, 20, 10};
|
||||
std::vector<int64_t> index_dims{8, 10};
|
||||
std::vector<int64_t> weight_dims{20};
|
||||
std::vector<int64_t> Y_dims{};
|
||||
std::vector<int64_t> Y_dims_none{8, 10};
|
||||
std::vector<int64_t> log_prob_dims{8, 20, 10};
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_MediumSizeTensor) {
|
||||
std::vector<int64_t> X_dims{8, 1024};
|
||||
std::vector<int64_t> index_dims{8};
|
||||
|
|
@ -369,6 +434,21 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_MediumSizeTensor) {
|
|||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none");
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLoss_MediumSizeTensor_half) {
|
||||
std::vector<int64_t> X_dims{8, 1024};
|
||||
std::vector<int64_t> index_dims{8};
|
||||
std::vector<int64_t> weight_dims{1024};
|
||||
std::vector<int64_t> Y_dims{};
|
||||
std::vector<int64_t> Y_dims_none{8};
|
||||
std::vector<int64_t> log_prob_dims{8, 1024};
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims, &log_prob_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, &weight_dims, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLoss(&X_dims, &index_dims, nullptr, &Y_dims_none, &log_prob_dims, "none", -1, true, 5e-2);
|
||||
}
|
||||
|
||||
// TODO fix flaky test
|
||||
// failing random seed: 2873512643
|
||||
TEST(CudaKernelTest, DISABLED_SoftmaxCrossEntropyLoss_LargeSizeTensor) {
|
||||
|
|
@ -391,7 +471,9 @@ static void TestSoftmaxCrossEntropyLossGrad(const std::vector<int64_t>& dY_dims,
|
|||
const std::vector<int64_t>& index_dims,
|
||||
const std::vector<int64_t>& dX_dims,
|
||||
const std::string& reduction,
|
||||
const std::int64_t ignore_index = -1) {
|
||||
const std::int64_t ignore_index = -1,
|
||||
const bool test_fp16 = false,
|
||||
const double error_tolerance = 1e-4) {
|
||||
CompareOpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("reduction", reduction);
|
||||
test.AddAttribute("ignore_index", ignore_index);
|
||||
|
|
@ -405,16 +487,31 @@ static void TestSoftmaxCrossEntropyLossGrad(const std::vector<int64_t>& dY_dims,
|
|||
if (index_data.size() > 0) {
|
||||
index_data[0] = ignore_index;
|
||||
}
|
||||
if (test_fp16) {
|
||||
std::vector<MLFloat16> dY_data_half(dY_data.size());
|
||||
ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), int(dY_data.size()));
|
||||
test.AddInput<MLFloat16>("dY", dY_dims, dY_data_half);
|
||||
|
||||
test.AddInput<float>("dY", dY_dims, dY_data);
|
||||
test.AddInput<float>("log_prob", log_prob_dims, log_prob_data);
|
||||
test.AddInput<int64_t>("index", index_dims, index_data);
|
||||
std::vector<MLFloat16> log_prob_data_half(log_prob_data.size());
|
||||
ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), int(log_prob_data.size()));
|
||||
test.AddInput<MLFloat16>("log_prob", log_prob_dims, log_prob_data_half);
|
||||
|
||||
std::vector<float> dX_data = FillZeros<float>(dX_dims);
|
||||
test.AddInput<int64_t>("index", index_dims, index_data);
|
||||
|
||||
test.AddOutput<float>("dX", dX_dims, dX_data);
|
||||
std::vector<MLFloat16> dX_data = FillZeros<MLFloat16>(dX_dims);
|
||||
|
||||
test.CompareWithCPU(kGpuExecutionProvider);
|
||||
test.AddOutput<MLFloat16>("dX", dX_dims, dX_data);
|
||||
test.CompareWithCPU(kGpuExecutionProvider, error_tolerance, error_tolerance);
|
||||
} else {
|
||||
test.AddInput<float>("dY", dY_dims, dY_data);
|
||||
test.AddInput<float>("log_prob", log_prob_dims, log_prob_data);
|
||||
test.AddInput<int64_t>("index", index_dims, index_data);
|
||||
|
||||
std::vector<float> dX_data = FillZeros<float>(dX_dims);
|
||||
|
||||
test.AddOutput<float>("dX", dX_dims, dX_data);
|
||||
test.CompareWithCPU(kGpuExecutionProvider);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_TinySizeTensor) {
|
||||
|
|
@ -452,5 +549,40 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_LargeSizeTensor) {
|
|||
TestSoftmaxCrossEntropyLossGrad({2, 30528}, log_prob_dims, index_dims, dX_dims, "none");
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_TinySizeTensor_half) {
|
||||
std::vector<int64_t> dY_dims{};
|
||||
std::vector<int64_t> log_prob_dims{8, 2};
|
||||
std::vector<int64_t> index_dims{8};
|
||||
std::vector<int64_t> dX_dims{8, 2};
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad({8}, log_prob_dims, index_dims, dX_dims, "none", -1, true, 5e-2);
|
||||
|
||||
// Just test ignore_index for small tensor because it will increase test time a lot with little verification gain.
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", 0, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", 0, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad({8}, log_prob_dims, index_dims, dX_dims, "none", 0, true, 5e-2);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_SmallSizeTensor_half) {
|
||||
std::vector<int64_t> dY_dims{};
|
||||
std::vector<int64_t> log_prob_dims{8, 20, 10};
|
||||
std::vector<int64_t> index_dims{8, 10};
|
||||
std::vector<int64_t> dX_dims{8, 20, 10};
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad({8, 10}, log_prob_dims, index_dims, dX_dims, "none", -1, true, 5e-2);
|
||||
}
|
||||
|
||||
TEST(CudaKernelTest, SoftmaxCrossEntropyLossGrad_LargeSizeTensor_half) {
|
||||
std::vector<int64_t> dY_dims{};
|
||||
std::vector<int64_t> log_prob_dims{2, 512, 30528};
|
||||
std::vector<int64_t> index_dims{2, 30528};
|
||||
std::vector<int64_t> dX_dims{2, 512, 30528};
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "mean", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad(dY_dims, log_prob_dims, index_dims, dX_dims, "sum", -1, true, 5e-2);
|
||||
TestSoftmaxCrossEntropyLossGrad({2, 30528}, log_prob_dims, index_dims, dX_dims, "none", -1, true, 5e-2);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -53,9 +53,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropy);
|
||||
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropyGrad);
|
||||
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SoftmaxGrad);
|
||||
|
|
@ -254,9 +257,12 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LogSoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, LogSoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LogSoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad)>,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/providers/cuda/math/softmax.h"
|
||||
#include "core/providers/cuda/reduction/reduction_functions.h"
|
||||
#include "core/providers/cuda/tensor/transpose.h"
|
||||
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cpu/controlflow/scan_utils.h"
|
||||
#include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h"
|
||||
#include "orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h"
|
||||
|
|
@ -37,6 +38,7 @@ namespace cuda {
|
|||
|
||||
template <typename T, typename Tin>
|
||||
Status SoftmaxCrossEntropyLoss<T, Tin>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const Tensor& logit = *ctx->Input<Tensor>(0);
|
||||
const Tensor& label = *ctx->Input<Tensor>(1);
|
||||
const TensorShape logit_shape{logit.Shape()};
|
||||
|
|
@ -108,38 +110,44 @@ Status SoftmaxCrossEntropyLoss<T, Tin>::ComputeInternal(OpKernelContext* ctx) co
|
|||
IAllocatorUniquePtr<T> weight_data_nd = GetScratchBuffer<T>(N_D);
|
||||
T* weight_data_nd_data = weight_data_nd.get();
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(weight_data_nd_data, 0, N_D * sizeof(T), Stream()));
|
||||
ComputeWeightsSoftmaxCrossEntropyImpl(Stream(), label_data, weight_data, N_D, C, ignore_index_, weight_data_nd_data);
|
||||
ComputeWeightsSoftmaxCrossEntropyImpl(Stream(),
|
||||
label_data,
|
||||
reinterpret_cast<const CudaT*>(weight_data),
|
||||
N_D, C,
|
||||
ignore_index_,
|
||||
reinterpret_cast<CudaT*>(weight_data_nd_data));
|
||||
|
||||
// Compute buffer size in byte for reduction APIs.
|
||||
const auto buffer_size =
|
||||
compute_reduction_buffer_size<T>(static_cast<int>(N_D));
|
||||
compute_reduction_buffer_size<CudaT>(static_cast<int>(N_D));
|
||||
// Allocate reduction buffer whose size is buffer_size bytes, or nullptr if no reduction.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(
|
||||
reduction_ != ReductionType::NONE ? buffer_size : 0);
|
||||
|
||||
auto normalize_factor_data = GetScratchBuffer<T>(1);
|
||||
typedef AccumulationType_t<CudaT> TBuf;
|
||||
auto normalize_factor_data = GetScratchBuffer<TBuf>(1);
|
||||
if (reduction_ == ReductionType::MEAN) {
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
Stream(),
|
||||
weight_data_nd_data,
|
||||
reinterpret_cast<CudaT*>(weight_data_nd_data),
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N_D),
|
||||
reduction_buffer.get(),
|
||||
buffer_size));
|
||||
} else {
|
||||
const T normalize_factor = static_cast<T>(1);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice, Stream()));
|
||||
const TBuf normalize_factor = static_cast<TBuf>(1.0f);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(TBuf), cudaMemcpyHostToDevice, Stream()));
|
||||
}
|
||||
|
||||
SoftmaxCrossEntropyLossImpl(Stream(),
|
||||
log_prob_data,
|
||||
reinterpret_cast<CudaT*>(log_prob_data),
|
||||
label_data,
|
||||
weight_data_nd_data,
|
||||
reinterpret_cast<CudaT*>(weight_data_nd_data),
|
||||
normalize_factor_data.get(),
|
||||
N_D,
|
||||
C,
|
||||
ignore_index_,
|
||||
tmp_loss_sample_buffer);
|
||||
reinterpret_cast<CudaT*>(tmp_loss_sample_buffer));
|
||||
|
||||
// Transpose log probability from [N, D1, D2...Dk, C] to [N, C, D1, D2 .. Dk].
|
||||
if (logit_shape.NumDimensions() > 2 && log_prob != nullptr) {
|
||||
|
|
@ -159,8 +167,8 @@ Status SoftmaxCrossEntropyLoss<T, Tin>::ComputeInternal(OpKernelContext* ctx) co
|
|||
// ReduceSum on loss_per_sample
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
Stream(),
|
||||
tmp_loss_sample_buffer,
|
||||
total_loss_data,
|
||||
reinterpret_cast<CudaT*>(tmp_loss_sample_buffer),
|
||||
reinterpret_cast<CudaT*>(total_loss_data),
|
||||
static_cast<int>(N_D),
|
||||
reduction_buffer.get(),
|
||||
buffer_size));
|
||||
|
|
@ -171,6 +179,7 @@ Status SoftmaxCrossEntropyLoss<T, Tin>::ComputeInternal(OpKernelContext* ctx) co
|
|||
|
||||
template <typename T, typename Tin>
|
||||
Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const Tensor& dY = *ctx->Input<Tensor>(0);
|
||||
const Tensor& log_prob = *ctx->Input<Tensor>(1);
|
||||
const Tensor& label = *ctx->Input<Tensor>(2);
|
||||
|
|
@ -212,37 +221,43 @@ Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx
|
|||
IAllocatorUniquePtr<T> weight_data_nd = GetScratchBuffer<T>(N_D);
|
||||
T* weight_data_nd_data = weight_data_nd.get();
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(weight_data_nd_data, 0, N_D * sizeof(T), Stream()));
|
||||
ComputeWeightsSoftmaxCrossEntropyImpl(Stream(), label_data, weight_data, N_D, C, ignore_index_, weight_data_nd_data);
|
||||
auto normalize_factor_data = GetScratchBuffer<T>(1);
|
||||
ComputeWeightsSoftmaxCrossEntropyImpl(Stream(),
|
||||
label_data,
|
||||
reinterpret_cast<const CudaT*>(weight_data),
|
||||
N_D, C,
|
||||
ignore_index_,
|
||||
reinterpret_cast<CudaT*>(weight_data_nd_data));
|
||||
typedef AccumulationType_t<CudaT> TBuf;
|
||||
auto normalize_factor_data = GetScratchBuffer<TBuf>(1);
|
||||
if (reduction_ == ReductionType::MEAN) {
|
||||
// Compute buffer size in byte for reduction APIs.
|
||||
const auto buffer_size =
|
||||
compute_reduction_buffer_size<T>(static_cast<int>(N_D));
|
||||
compute_reduction_buffer_size<CudaT>(static_cast<int>(N_D));
|
||||
// Allocate reduction buffer whose size is buffer_size bytes.
|
||||
IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(
|
||||
buffer_size);
|
||||
ORT_RETURN_IF_ERROR(reduce_sum(
|
||||
Stream(),
|
||||
weight_data_nd_data,
|
||||
reinterpret_cast<const CudaT*>(weight_data_nd_data),
|
||||
normalize_factor_data.get(),
|
||||
static_cast<int>(N_D),
|
||||
reduction_buffer.get(),
|
||||
buffer_size));
|
||||
} else {
|
||||
const T normalize_factor = static_cast<T>(1);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(T), cudaMemcpyHostToDevice, Stream()));
|
||||
const TBuf normalize_factor = static_cast<TBuf>(1.0f);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(TBuf), cudaMemcpyHostToDevice, Stream()));
|
||||
}
|
||||
|
||||
SoftmaxCrossEntropyLossGradImpl(Stream(),
|
||||
dY_data,
|
||||
log_prob_data,
|
||||
reinterpret_cast<const CudaT*>(dY_data),
|
||||
reinterpret_cast<const CudaT*>(log_prob_data),
|
||||
label_data,
|
||||
weight_data_nd_data,
|
||||
reinterpret_cast<const CudaT*>(weight_data_nd_data),
|
||||
normalize_factor_data.get(),
|
||||
N_D,
|
||||
C,
|
||||
ReductionType::NONE == reduction_,
|
||||
d_logit_data);
|
||||
reinterpret_cast<CudaT*>(d_logit_data));
|
||||
|
||||
// Transpose logit from [N, D1, D2...Dk, C] to [N, C, D1, D2 .. Dk]
|
||||
if (probability_shape.NumDimensions() > 2) {
|
||||
|
|
@ -261,16 +276,19 @@ Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define SPECIALIZED_VERSIONED_COMPUTE_SPARSE(Class, T, Tin, domain, startver, endvar) \
|
||||
#define INSTANTIATE_VERSIONED_COMPUTE_SPARSE(Class, T, Tin, domain, startver, endvar) \
|
||||
REGISTER_KERNEL_VERSIONED_TYPED_TWO_TYPES(Class, T, Tin, domain, startver, endvar)
|
||||
|
||||
#define SPECIALIZED_COMPUTE_SPARSE(Class, T, Tin, domain, version) \
|
||||
#define INSTANTIATE_COMPUTE_SPARSE(Class, T, Tin, domain, version) \
|
||||
REGISTER_KERNEL_TYPED_TWO_TYPES(Class, T, Tin, domain, version) \
|
||||
template Status Class<T, Tin>::ComputeInternal(OpKernelContext* ctx) const;
|
||||
|
||||
SPECIALIZED_VERSIONED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 12, 12)
|
||||
SPECIALIZED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 13)
|
||||
SPECIALIZED_COMPUTE_SPARSE(SoftmaxCrossEntropyLossGrad, float, int64_t, kMSDomain, 1)
|
||||
INSTANTIATE_VERSIONED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 12, 12)
|
||||
INSTANTIATE_VERSIONED_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, MLFloat16, int64_t, kOnnxDomain, 12, 12)
|
||||
INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, float, int64_t, kOnnxDomain, 13)
|
||||
INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLoss, MLFloat16, int64_t, kOnnxDomain, 13)
|
||||
INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLossGrad, float, int64_t, kMSDomain, 1)
|
||||
INSTANTIATE_COMPUTE_SPARSE(SoftmaxCrossEntropyLossGrad, MLFloat16, int64_t, kMSDomain, 1)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -17,9 +17,10 @@ __global__ void _ComputeWeightsSoftmaxCrossEntropy(
|
|||
CUDA_LONG C,
|
||||
CUDA_LONG ignore_index) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(i, N_D);
|
||||
const T ONE_T = 1;
|
||||
if (label_data[i] != ignore_index) {
|
||||
CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < C);
|
||||
weight_data_nd[i] = weight_data != nullptr ? weight_data[label_data[i]] : 1;
|
||||
weight_data_nd[i] = weight_data != nullptr ? weight_data[label_data[i]] : ONE_T;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -45,12 +46,12 @@ void ComputeWeightsSoftmaxCrossEntropyImpl(
|
|||
II);
|
||||
}
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename TAcc, typename Tin>
|
||||
__global__ void _WeightedSoftmaxCrossEntropyLoss(
|
||||
const T* log_prob_data,
|
||||
const Tin* label_data,
|
||||
const T* weight_data,
|
||||
const T* normalize_factor_data,
|
||||
const TAcc* normalize_factor_data,
|
||||
T* output_data,
|
||||
CUDA_LONG N_D,
|
||||
CUDA_LONG C,
|
||||
|
|
@ -60,17 +61,18 @@ __global__ void _WeightedSoftmaxCrossEntropyLoss(
|
|||
output_data[i] = 0;
|
||||
} else {
|
||||
CUDA_KERNEL_ASSERT(label_data[i] >= 0 && label_data[i] < C);
|
||||
output_data[i] = -log_prob_data[i * C + label_data[i]] * weight_data[i] / (*normalize_factor_data);
|
||||
output_data[i] = static_cast<T>(static_cast<TAcc>(-log_prob_data[i * C + label_data[i]] * weight_data[i]) /
|
||||
*normalize_factor_data);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename TAcc, typename Tin>
|
||||
void SoftmaxCrossEntropyLossImpl(
|
||||
cudaStream_t stream,
|
||||
const T* log_prob,
|
||||
const Tin* label,
|
||||
const T* weight,
|
||||
const T* normalize_factor,
|
||||
const TAcc* normalize_factor,
|
||||
size_t count,
|
||||
size_t label_depth,
|
||||
int64_t ignore_index,
|
||||
|
|
@ -79,7 +81,7 @@ void SoftmaxCrossEntropyLossImpl(
|
|||
CUDA_LONG N_D = static_cast<CUDA_LONG>(count);
|
||||
CUDA_LONG C = static_cast<CUDA_LONG>(label_depth);
|
||||
CUDA_LONG II = static_cast<CUDA_LONG>(ignore_index);
|
||||
_WeightedSoftmaxCrossEntropyLoss<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
_WeightedSoftmaxCrossEntropyLoss<T, TAcc, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
log_prob,
|
||||
label,
|
||||
weight,
|
||||
|
|
@ -90,28 +92,29 @@ void SoftmaxCrossEntropyLossImpl(
|
|||
II);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL_SoftMaxEntropyLossImpl(T, Tin) \
|
||||
template void SoftmaxCrossEntropyLossImpl( \
|
||||
cudaStream_t stream, \
|
||||
const T* log_prob, \
|
||||
const Tin* label, \
|
||||
const T* weight, \
|
||||
const T* normalize_factor, \
|
||||
size_t count, \
|
||||
size_t label_depth, \
|
||||
int64_t ignore_index, \
|
||||
#define INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(T, TAcc, Tin) \
|
||||
template void SoftmaxCrossEntropyLossImpl( \
|
||||
cudaStream_t stream, \
|
||||
const T* log_prob, \
|
||||
const Tin* label, \
|
||||
const T* weight, \
|
||||
const TAcc* normalize_factor, \
|
||||
size_t count, \
|
||||
size_t label_depth, \
|
||||
int64_t ignore_index, \
|
||||
T* output_data);
|
||||
|
||||
SPECIALIZED_IMPL_SoftMaxEntropyLossImpl(float, int32_t)
|
||||
SPECIALIZED_IMPL_SoftMaxEntropyLossImpl(float, int64_t)
|
||||
INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(float, float, int32_t)
|
||||
INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(float, float, int64_t)
|
||||
INSTANTIATE_IMPL_SoftMaxEntropyLossImpl(half, float, int64_t)
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename TAcc, typename Tin>
|
||||
__global__ void _WeightedSoftmaxCrossEntropyLossGrad(
|
||||
const T* dY,
|
||||
const T* log_prob,
|
||||
const Tin* label,
|
||||
const T* weight,
|
||||
const T* normalize_factor,
|
||||
const TAcc* normalize_factor,
|
||||
T* output_data,
|
||||
CUDA_LONG N_D,
|
||||
CUDA_LONG C) {
|
||||
|
|
@ -119,24 +122,29 @@ __global__ void _WeightedSoftmaxCrossEntropyLossGrad(
|
|||
|
||||
int row = i / C;
|
||||
int d = i % C;
|
||||
CUDA_KERNEL_ASSERT(weight[row] == 0 || (label[row] >= 0 && label[row] < C));
|
||||
if(0 == *normalize_factor){
|
||||
// normalize_factor is sum of labels' weights. Because zero
|
||||
// sum implies all weights are 0, the loss function should
|
||||
const T ZERO_T = 0;
|
||||
const TAcc ZERO_TAcc = 0;
|
||||
const TAcc ONE_TAcc = 1;
|
||||
CUDA_KERNEL_ASSERT(weight[row] == ZERO_T || (label[row] >= 0 && label[row] < C));
|
||||
if (ZERO_TAcc == *normalize_factor) {
|
||||
// normalize_factor is sum of labels' weights. Because zero
|
||||
// sum implies all weights are 0, the loss function should
|
||||
// be constant 0 and its corresponding gradient should be 0 as well.
|
||||
output_data[i] = 0;
|
||||
output_data[i] = ZERO_T;
|
||||
} else {
|
||||
output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
|
||||
output_data[i] = static_cast<T>(static_cast<TAcc>((*dY) * weight[row]) *
|
||||
(_Exp(static_cast<TAcc>(log_prob[i])) - ONE_TAcc * (TAcc)(d == label[row])) /
|
||||
(*normalize_factor));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename TAcc, typename Tin>
|
||||
__global__ void _WeightedReductionNoneSoftmaxCrossEntropyLossGrad(
|
||||
const T* dY,
|
||||
const T* log_prob,
|
||||
const Tin* label,
|
||||
const T* weight,
|
||||
const T* normalize_factor,
|
||||
const TAcc* normalize_factor,
|
||||
T* output_data,
|
||||
CUDA_LONG N_D,
|
||||
CUDA_LONG C) {
|
||||
|
|
@ -144,25 +152,30 @@ __global__ void _WeightedReductionNoneSoftmaxCrossEntropyLossGrad(
|
|||
|
||||
int row = i / C;
|
||||
int d = i % C;
|
||||
CUDA_KERNEL_ASSERT(weight[row] == 0 || (label[row] >= 0 && label[row] < C));
|
||||
if(0 == *normalize_factor){
|
||||
// normalize_factor is sum of labels' weights. Because zero
|
||||
// sum implies all weights are 0, the loss function should
|
||||
const T ZERO_T = 0;
|
||||
const TAcc ZERO_TAcc = 0;
|
||||
const TAcc ONE_TAcc = 1;
|
||||
CUDA_KERNEL_ASSERT(weight[row] == ZERO_T || (label[row] >= 0 && label[row] < C));
|
||||
if (ZERO_TAcc == *normalize_factor) {
|
||||
// normalize_factor is sum of labels' weights. Because zero
|
||||
// sum implies all weights are 0, the loss function should
|
||||
// be constant 0 and its corresponding gradient should be 0 as well.
|
||||
output_data[i] = 0;
|
||||
output_data[i] = ZERO_T;
|
||||
} else {
|
||||
output_data[i] = dY[row] * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
|
||||
output_data[i] = static_cast<T>(static_cast<TAcc>(dY[row] * weight[row]) *
|
||||
(_Exp(static_cast<TAcc>(log_prob[i])) - ONE_TAcc * (TAcc)(d == label[row])) /
|
||||
(*normalize_factor));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename TAcc, typename Tin>
|
||||
void SoftmaxCrossEntropyLossGradImpl(
|
||||
cudaStream_t stream,
|
||||
const T* dY,
|
||||
const T* log_prob,
|
||||
const Tin* label,
|
||||
const T* weight,
|
||||
const T* normalize_factor,
|
||||
const TAcc* normalize_factor,
|
||||
size_t count,
|
||||
size_t label_depth,
|
||||
bool reduction_none,
|
||||
|
|
@ -171,7 +184,7 @@ void SoftmaxCrossEntropyLossGradImpl(
|
|||
CUDA_LONG C = static_cast<CUDA_LONG>(label_depth);
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N_D * C) / GridDim::maxThreadsPerBlock));
|
||||
if (reduction_none) {
|
||||
_WeightedReductionNoneSoftmaxCrossEntropyLossGrad<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
_WeightedReductionNoneSoftmaxCrossEntropyLossGrad<T, TAcc, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
dY,
|
||||
log_prob,
|
||||
label,
|
||||
|
|
@ -181,7 +194,7 @@ void SoftmaxCrossEntropyLossGradImpl(
|
|||
N_D,
|
||||
C);
|
||||
} else {
|
||||
_WeightedSoftmaxCrossEntropyLossGrad<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
_WeightedSoftmaxCrossEntropyLossGrad<T, TAcc, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
dY,
|
||||
log_prob,
|
||||
label,
|
||||
|
|
@ -193,23 +206,24 @@ void SoftmaxCrossEntropyLossGradImpl(
|
|||
}
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(T, Tin) \
|
||||
template void SoftmaxCrossEntropyLossGradImpl( \
|
||||
cudaStream_t stream, \
|
||||
const T* dY, \
|
||||
const T* log_prob, \
|
||||
const Tin* label, \
|
||||
const T* weight, \
|
||||
const T* normalize_factor, \
|
||||
size_t count, \
|
||||
size_t label_depth, \
|
||||
bool reducation_none, \
|
||||
#define INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(T, TAcc, Tin) \
|
||||
template void SoftmaxCrossEntropyLossGradImpl( \
|
||||
cudaStream_t stream, \
|
||||
const T* dY, \
|
||||
const T* log_prob, \
|
||||
const Tin* label, \
|
||||
const T* weight, \
|
||||
const TAcc* normalize_factor, \
|
||||
size_t count, \
|
||||
size_t label_depth, \
|
||||
bool reducation_none, \
|
||||
T* output_data);
|
||||
|
||||
SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(float, int32_t)
|
||||
SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(float, int64_t)
|
||||
INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(float, float, int32_t)
|
||||
INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(float, float, int64_t)
|
||||
INSTANTIATE_IMPL_SoftMaxEntropyLossGradImpl(half, float, int64_t)
|
||||
|
||||
#define SPECIALIZED_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(T, Tin) \
|
||||
#define INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(T, Tin) \
|
||||
template void ComputeWeightsSoftmaxCrossEntropyImpl( \
|
||||
cudaStream_t stream, \
|
||||
const Tin* label, \
|
||||
|
|
@ -219,8 +233,9 @@ SPECIALIZED_IMPL_SoftMaxEntropyLossGradImpl(float, int64_t)
|
|||
int64_t ignore_index, \
|
||||
T* weight_data_nd);
|
||||
|
||||
SPECIALIZED_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int32_t)
|
||||
SPECIALIZED_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int64_t)
|
||||
INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int32_t)
|
||||
INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(float, int64_t)
|
||||
INSTANTIATE_IMPL_ComputeWeightsSoftmaxCrossEntropyImpl(half, int64_t)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -10,26 +10,26 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename TAcc, typename Tin>
|
||||
void SoftmaxCrossEntropyLossImpl(
|
||||
cudaStream_t stream,
|
||||
const T* log_prob,
|
||||
const Tin* label,
|
||||
const T* weight,
|
||||
const T* normalize_factor,
|
||||
const TAcc* normalize_factor,
|
||||
size_t count,
|
||||
size_t label_depth,
|
||||
int64_t ignore_index,
|
||||
T* output_data);
|
||||
|
||||
template <typename T, typename Tin>
|
||||
template <typename T, typename TAcc, typename Tin>
|
||||
void SoftmaxCrossEntropyLossGradImpl(
|
||||
cudaStream_t stream,
|
||||
const T* dY,
|
||||
const T* log_prob,
|
||||
const Tin* label,
|
||||
const T* weight,
|
||||
const T* normalize_factor,
|
||||
const TAcc* normalize_factor,
|
||||
size_t count,
|
||||
size_t label_depth,
|
||||
bool reduction_none,
|
||||
|
|
|
|||
|
|
@ -52,8 +52,11 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
|
|||
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropyGrad);
|
||||
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SoftmaxGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SoftmaxGrad);
|
||||
|
|
@ -193,8 +196,11 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, LogSoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LogSoftmaxGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, int64_t, SoftmaxCrossEntropyLoss)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, int64_t, SoftmaxCrossEntropyLossGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GatherGrad)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue