From da07c83948a456e75f87fadd97de6d5437d2ec7c Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 14 Sep 2022 14:45:51 +0800 Subject: [PATCH] SoftmaxCrossEntropyLossInternalGrad and Sum Fusion (#12746) * fuse scegrad and sum * add yield output shapes to value_info * resolve comments * fix merge main --- .../core/optimizer/graph_transformer_utils.cc | 2 + .../optimizer/graph_transform_test_builder.h | 17 ++ .../core/framework/ortmodule_graph_builder.cc | 9 + .../core/graph/training_op_defs.cc | 21 +- .../optimizer/sce_loss_grad_bias_fusion.cc | 114 +++++++++ .../optimizer/sce_loss_grad_bias_fusion.h | 24 ++ .../test/optimizer/graph_transform_test.cc | 224 ++++++++++++++++++ .../training_ops/cuda/cross_entropy_test.cc | 72 ++++-- .../cpu/loss/softmax_cross_entropy_loss.cc | 14 +- .../loss/softmax_cross_entropy_loss_impl.cc | 18 +- .../loss/softmax_cross_entropy_loss_impl.cu | 46 ++-- .../loss/softmax_cross_entropy_loss_impl.h | 1 + 12 files changed, 519 insertions(+), 43 deletions(-) create mode 100644 orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc create mode 100644 orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.h diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 20c0275bff..4b029f9176 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -65,6 +65,7 @@ #ifdef ENABLE_TRAINING #include "orttraining/core/optimizer/bitmask_dropout_replacement.h" #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h" +#include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h" #endif #endif // !defined(ORT_MINIMAL_BUILD) @@ -256,6 +257,7 @@ InlinedVector> GenerateTransformers( #ifdef ENABLE_TRAINING transformers.emplace_back(std::make_unique(cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); #endif transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 4f57f0b815..341059f2b1 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -96,6 +96,23 @@ class ModelTestBuilder { return &graph_.GetOrCreateNodeArg(name, nullptr); } + template + NodeArg* MakeIntermediate(const std::optional>& shape) { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); + if (shape != std::nullopt) { + type_proto.mutable_tensor_type()->mutable_shape(); + for (auto& d : *shape) { + auto dim = type_proto.mutable_tensor_type()->mutable_shape()->add_dim(); + if (d != -1) { + dim->set_dim_value(d); + } + } + } + std::string name = graph_.GenerateNodeArgName("node"); + return &graph_.GetOrCreateNodeArg(name, &type_proto); + } + template NodeArg* MakeInitializer(const std::vector& shape, const std::vector& data) { std::string name = graph_.GenerateNodeArgName("constant"); diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index 0f917a914a..c1ee149ece 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -448,6 +448,15 @@ void OrtModuleGraphBuilder::FindModuleOutputNeededForBackward() { } } } + + // Graph resolve will have the YieldOp outputs' shapes inferred. To avoid lossing these information when + // transferring model from backend to frontend (in case any graph optimization requires these shape information), + // add them to graph's ValueInfo. + for (const auto& node_def : yield_node->OutputDefs()) { + if (node_def->TypeAsProto()) { + gradient_graph.AddValueInfo(node_def); + } + } } void OrtModuleGraphBuilder::UpdatePythonOpInputsRequireGradInfo( diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 973a2c098a..84ae3ffd48 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -333,11 +333,23 @@ bool SCELossGradFunBuilder(bool ignore_index_as_attr, const FunctionBodyBuildCon )"); builder.Add(R"( - adj_BCD = CastLike (one_hot_label_BCD, prob_BCD) - grad_BCD = Sub (prob_BCD, adj_BCD) - d_logits_BCD = Mul (d_loss_B1D, grad_BCD) - d_logits = Reshape (d_logits_BCD, orig_shape) + adj_BCD = CastLike (one_hot_label_BCD, prob_BCD) + grad_BCD = Sub (prob_BCD, adj_BCD) + d_logits_BCD = Mul (d_loss_B1D, grad_BCD) )"); + + if (ctx.hasInput(5)) { + builder.Add(R"( + d_logits_without_bias = Reshape (d_logits_BCD, orig_shape) + bias_shaped = Reshape (bias, orig_shape) + d_logits = Add(d_logits_without_bias, bias_shaped) + )"); + } else { + builder.Add(R"( + d_logits = Reshape (d_logits_BCD, orig_shape) + )"); + } + schema.BuildFunction(functionProto); return true; }; @@ -3892,6 +3904,7 @@ Return true if all elements are true and false otherwise. .Input(4, "ignore_index", "Scalar tensor to specify a target value that is ignored and does not contribute to the input gradient.", "I", OpSchema::Optional) + .Input(5, "bias", "data to be non-broadcasting added to the gradient.", "T", OpSchema::Optional) .Output(0, "d_logits", "gradient of logits", "T") .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, "Constrain to float, float16 and double tensors.") diff --git a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc new file mode 100644 index 0000000000..c4af401d93 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h" + +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +/** +Fuse SoftmaxCrossEntropyLossInternalGrad + Reshape(optional) + Sum/Add to SoftmaxCrossEntropyLossInternalGrad. +If it's Sum Op, it requires that it has only 2 inputs. Sum/Add must be non-broadcasting computation. +*/ +Status SceLossGradBiasFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (!node_ptr) continue; // Node was removed. + + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLossInternalGrad", {1}, kMSDomain) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.InputDefs().size() == 6 || + node.GetOutputEdgesCount() != 1) { + continue; + } + + NodeArg* sce_grad_def = node.MutableOutputDefs()[0]; + Node* p_next = graph.GetNode(node.OutputNodesBegin()->Index()); + Node* p_reshape = nullptr; + if (graph_utils::IsSupportedOptypeVersionAndDomain(*p_next, "Reshape", {5, 13, 14}) && + graph_utils::IsSupportedProvider(*p_next, GetCompatibleExecutionProviders()) && + p_next->GetOutputEdgesCount() == 1) { + p_reshape = p_next; + sce_grad_def = p_reshape->MutableOutputDefs()[0]; + p_next = graph.GetNode(p_next->OutputNodesBegin()->Index()); + } + + if (!(graph_utils::IsSupportedOptypeVersionAndDomain(*p_next, "Add", {7, 13, 14}) || + graph_utils::IsSupportedOptypeVersionAndDomain(*p_next, "Sum", {6, 8, 13})) || + !graph_utils::IsSupportedProvider(*p_next, GetCompatibleExecutionProviders()) || + p_next->InputDefs().size() != 2) { + continue; + } + + Node& sum_node = *p_next; + auto& sum_input_defs = sum_node.MutableInputDefs(); + auto shape0 = sum_input_defs[0]->Shape(); + auto shape1 = sum_input_defs[1]->Shape(); + if (!shape0 || !shape1 || shape0->dim_size() != shape1->dim_size()) { + continue; + } + + bool has_same_shape = true; + for (int i = 0; i < shape0->dim_size(); ++i) { + if (shape0->dim(i) != shape1->dim(i)) { + has_same_shape = false; + break; + } + } + + if (!has_same_shape) continue; + + NodeArg* bias_def = sce_grad_def == sum_input_defs[0] ? sum_input_defs[1] : sum_input_defs[0]; + auto& scegrad_inputs = node.MutableInputDefs(); + InlinedVector new_scegrad_node_inputs{scegrad_inputs[0], scegrad_inputs[1], scegrad_inputs[2]}; + InlinedVector new_scegrad_node_outputs; + if (scegrad_inputs.size() >= 4) { + new_scegrad_node_inputs.emplace_back(scegrad_inputs[3]); + } else { + new_scegrad_node_inputs.emplace_back(&graph.GetOrCreateNodeArg("", nullptr)); + } + if (scegrad_inputs.size() >= 5) { + new_scegrad_node_inputs.emplace_back(scegrad_inputs[4]); + } else { + ONNX_NAMESPACE::TensorProto ignore_index_initializer_proto; + ignore_index_initializer_proto.set_name(graph.GenerateNodeArgName("sce_grad_ignore_index")); + ignore_index_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ignore_index_initializer_proto.add_int64_data(static_cast(-1)); + new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializer(graph, ignore_index_initializer_proto)); + } + new_scegrad_node_inputs.emplace_back(bias_def); + if (!p_reshape) { + new_scegrad_node_outputs.emplace_back(sum_node.MutableOutputDefs()[0]); + } else { + new_scegrad_node_outputs.emplace_back(p_reshape->MutableInputDefs()[0]); + } + Node& new_scegrad_node = + graph.AddNode(graph.GenerateNodeName("FusedSoftmaxCrossEntropyLossInternalGrad"), + "SoftmaxCrossEntropyLossInternalGrad", "FusedSoftmaxCrossEntropyLossInternalGrad", + new_scegrad_node_inputs, new_scegrad_node_outputs, &node.GetAttributes(), kMSDomain); + new_scegrad_node.SetExecutionProviderType(node.GetExecutionProviderType()); + + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); + if (p_reshape) { + graph_utils::FinalizeNodeFusion(graph, *p_reshape, sum_node); + } else { + graph_utils::RemoveNodeOutputEdges(graph, sum_node); + graph.RemoveNode(sum_node.Index()); + } + + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.h b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.h new file mode 100644 index 0000000000..29f93dfbea --- /dev/null +++ b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@Class SceLossGradBiasFusion +Fuse SoftmaxCrossEntropyLossInternalGrad + Reshape(optional) + Sum/Add to SoftmaxCrossEntropyLossInternalGrad. +If it's Sum Op, it requires that it has only 2 inputs. Sum/Add must be non-broadcasting computation. +*/ +class SceLossGradBiasFusion : public GraphTransformer { + public: + explicit SceLossGradBiasFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("SceLossGradBiasFusion", compatible_execution_providers) { + } + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 01c12cbf96..dfd799d851 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -26,6 +26,7 @@ #include "orttraining/core/session/training_session.h" #include "orttraining/core/optimizer/loss_rewriter.h" #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h" +#include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h" #include @@ -281,6 +282,229 @@ TEST_F(GraphTransformationTests, BiasSoftmaxDropoutFusion) { RunBiasSoftmaxDropoutFusionTest(true, true, 14, *logger_); } +template +void RunSceLossGradBiasFusionTest(bool has_reshape, bool is_add_op, bool is_bias_lhs_input, bool has_weight, + bool has_ignore_index, const std::string& reduction, int opset_version, + const logging::Logger& logger) { + std::string bias_op_type = is_add_op ? "Add" : "Sum"; + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* dY_arg = builder.MakeInput({{}}); + auto* log_prob_arg = builder.MakeInput({{8, 2}}); + auto* index_arg = builder.MakeInput({{8}}); + std::vector scegrad_inputs{dY_arg, log_prob_arg, index_arg}; + if (has_weight || has_ignore_index) { + auto* weight_arg = builder.MakeInput({{2}}); + scegrad_inputs.emplace_back(weight_arg); + } + if (has_ignore_index) { + auto* ignore_index_arg = builder.MakeInput({{}}); + scegrad_inputs.emplace_back(ignore_index_arg); + } + auto* sce_grad_out = builder.MakeIntermediate(); + std::vector reshape_inputs; + std::vector reshape_outputs; + std::vector bias_op_inputs; + if (has_reshape) { + reshape_inputs.emplace_back(sce_grad_out); + auto* shape_arg = builder.MakeInput({{1}}); + reshape_inputs.emplace_back(shape_arg); + auto* reshape_out = builder.MakeIntermediate({{16}}); + reshape_outputs.emplace_back(reshape_out); + auto* bias_arg = builder.MakeInput({{16}}); + if (is_bias_lhs_input) { + bias_op_inputs.emplace_back(bias_arg); + bias_op_inputs.emplace_back(reshape_out); + } else { + bias_op_inputs.emplace_back(reshape_out); + bias_op_inputs.emplace_back(bias_arg); + } + } else { + auto* bias_arg = builder.MakeInput({{8, 2}}); + if (is_bias_lhs_input) { + bias_op_inputs.emplace_back(bias_arg); + bias_op_inputs.emplace_back(sce_grad_out); + } else { + bias_op_inputs.emplace_back(sce_grad_out); + bias_op_inputs.emplace_back(bias_arg); + } + } + auto* dx_out = builder.MakeOutput(); + + builder.AddNode("SoftmaxCrossEntropyLossInternalGrad", scegrad_inputs, {sce_grad_out}, kMSDomain) + .AddAttribute("reduction", reduction); + if (has_reshape) { + builder.AddNode("Reshape", reshape_inputs, reshape_outputs); + } + builder.AddNode(bias_op_type, bias_op_inputs, {dx_out}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.SoftmaxCrossEntropyLossInternalGrad"], 1); + ASSERT_EQ(CountOpsInGraph(graph)[bias_op_type], 1); + }; + + auto post_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.SoftmaxCrossEntropyLossInternalGrad"], 1); + ASSERT_EQ(CountOpsInGraph(graph)[bias_op_type], 0); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "SoftmaxCrossEntropyLossInternalGrad") { + auto& attrs = node.GetAttributes(); + ASSERT_TRUE(attrs.find("reduction") != attrs.end()); + ASSERT_EQ(reduction, attrs.at("reduction").s()); + ASSERT_EQ(6, static_cast(node.InputDefs().size())); + } + } + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, opset_version, logger, std::move(transformer), TransformerLevel::Level2, 1, + pre_graph_checker, post_graph_checker); +} + +void RunSceLossGradBiasFusionTestWrapper(int opset_version, const logging::Logger& logger) { + RunSceLossGradBiasFusionTest(false, false, false, true, true, "none", opset_version, logger); + RunSceLossGradBiasFusionTest(false, false, true, true, false, "mean", opset_version, logger); + RunSceLossGradBiasFusionTest(false, false, false, false, false, "sum", opset_version, logger); + RunSceLossGradBiasFusionTest(false, true, true, true, true, "none", opset_version, logger); + RunSceLossGradBiasFusionTest(false, true, false, true, false, "mean", opset_version, logger); + RunSceLossGradBiasFusionTest(false, true, true, false, false, "sum", opset_version, logger); + RunSceLossGradBiasFusionTest(true, false, false, true, true, "none", opset_version, logger); + RunSceLossGradBiasFusionTest(true, false, true, true, false, "mean", opset_version, logger); + RunSceLossGradBiasFusionTest(true, false, false, false, false, "sum", opset_version, logger); + RunSceLossGradBiasFusionTest(true, true, true, true, true, "none", opset_version, logger); + RunSceLossGradBiasFusionTest(true, true, false, true, false, "mean", opset_version, logger); + RunSceLossGradBiasFusionTest(true, true, true, false, false, "sum", opset_version, logger); +} + +TEST_F(GraphTransformationTests, SceLossGradBiasFusion) { + RunSceLossGradBiasFusionTestWrapper(12, *logger_); + RunSceLossGradBiasFusionTestWrapper(13, *logger_); + RunSceLossGradBiasFusionTestWrapper(14, *logger_); +} + +TEST_F(GraphTransformationTests, SceLossGradBiasFusion_Invalid) { + auto pre_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.SoftmaxCrossEntropyLossInternalGrad"], 1); + ASSERT_EQ(CountOpsInGraph(graph)["Sum"], 1); + }; + + auto post_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.SoftmaxCrossEntropyLossInternalGrad"], 1); + ASSERT_EQ(CountOpsInGraph(graph)["Sum"], 1); + }; + + // Sum has more than 2 inputs. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* dY_arg = builder.MakeInput({{}}); + auto* log_prob_arg = builder.MakeInput({{8, 2}}); + auto* index_arg = builder.MakeInput({{8}}); + auto* sce_grad_out = builder.MakeIntermediate(); + auto* bias1_arg = builder.MakeInput({{8, 2}}); + auto* bias2_arg = builder.MakeInput({{8, 2}}); + auto* dx_out = builder.MakeOutput(); + builder + .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain) + .AddAttribute("reduction", "sum"); + builder.AddNode("Sum", {sce_grad_out, bias1_arg, bias2_arg}, {dx_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + pre_graph_checker, post_graph_checker); + } + + // SceGrad has more than 1 consumers. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* dY_arg = builder.MakeInput({{}}); + auto* log_prob_arg = builder.MakeInput({{8, 2}}); + auto* index_arg = builder.MakeInput({{8}}); + auto* sce_grad_out = builder.MakeIntermediate(); + auto* bias_arg = builder.MakeInput({{8, 2}}); + auto* dx_out = builder.MakeOutput(); + auto* identity_out = builder.MakeOutput(); + builder + .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain) + .AddAttribute("reduction", "sum"); + builder.AddNode("Sum", {sce_grad_out, bias_arg}, {dx_out}); + builder.AddNode("Identity", {sce_grad_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + pre_graph_checker, post_graph_checker); + } + + // Sum inputs shape mismatch. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* dY_arg = builder.MakeInput({{}}); + auto* log_prob_arg = builder.MakeInput({{8, 2}}); + auto* index_arg = builder.MakeInput({{8}}); + auto* sce_grad_out = builder.MakeIntermediate(); + auto* bias_arg = builder.MakeInput({{2}}); + auto* dx_out = builder.MakeOutput(); + builder + .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain) + .AddAttribute("reduction", "sum"); + builder.AddNode("Sum", {sce_grad_out, bias_arg}, {dx_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + pre_graph_checker, post_graph_checker); + } + + // Sum inputs shape mismatch. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* dY_arg = builder.MakeInput({{}}); + auto* log_prob_arg = builder.MakeInput({{8, 1}}); + auto* index_arg = builder.MakeInput({{8}}); + auto* bias_arg = builder.MakeInput({{8, 1}}); + auto* sce_grad_out = builder.MakeIntermediate(); + auto* shape_arg = builder.MakeInput({{2}}); + auto* reshape_out = builder.MakeIntermediate({{1, 8}}); + auto* dx_out = builder.MakeOutput(); + builder + .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain) + .AddAttribute("reduction", "sum"); + builder.AddNode("Reshape", {sce_grad_out, shape_arg}, {reshape_out}); + builder.AddNode("Sum", {reshape_out, bias_arg}, {dx_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + pre_graph_checker, post_graph_checker); + } + + // Reshape output has more than 1 consumers. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* dY_arg = builder.MakeInput({{}}); + auto* log_prob_arg = builder.MakeInput({{8, 2}}); + auto* index_arg = builder.MakeInput({{8}}); + auto* bias_arg = builder.MakeInput({{16}}); + auto* sce_grad_out = builder.MakeIntermediate(); + auto* shape_arg = builder.MakeInput({{1}}); + auto* reshape_out = builder.MakeIntermediate({{16}}); + auto* dx_out = builder.MakeOutput(); + auto* identity_out = builder.MakeOutput(); + builder + .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain) + .AddAttribute("reduction", "sum"); + builder.AddNode("Reshape", {sce_grad_out, shape_arg}, {reshape_out}); + builder.AddNode("Sum", {reshape_out, bias_arg}, {dx_out}); + builder.AddNode("Identity", {reshape_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level2, 1, + pre_graph_checker, post_graph_checker); + } +} + Node* GetNodeByName(Graph& graph, std::string node_name) { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc index f5413488ad..d79fcb7db3 100644 --- a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc @@ -299,7 +299,7 @@ static void TestSoftmaxCrossEntropyLoss(CompareOpTester& test, const std::vector } if (test_fp16) { std::vector X_data_half(X_data.size()); - ConvertFloatToMLFloat16(X_data.data(), X_data_half.data(), int(X_data.size())); + ConvertFloatToMLFloat16(X_data.data(), X_data_half.data(), static_cast(X_data.size())); test.AddInput("X", *X_dims, X_data_half); } else { test.AddInput("X", *X_dims, X_data); @@ -311,7 +311,7 @@ static void TestSoftmaxCrossEntropyLoss(CompareOpTester& test, const std::vector std::vector weight_data = random.Uniform(*weight_dims, 0.0f, 1.0f); if (test_fp16) { std::vector weight_data_half(weight_data.size()); - ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), int(weight_data.size())); + ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), static_cast(weight_data.size())); test.AddInput("weight", *weight_dims, weight_data_half); } else { test.AddInput("weight", *weight_dims, weight_data); @@ -513,11 +513,11 @@ static void TestSoftmaxCrossEntropyLossGrad(const std::vector& dY_dims, } if (test_fp16) { std::vector dY_data_half(dY_data.size()); - ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), int(dY_data.size())); + ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), static_cast(dY_data.size())); test.AddInput("dY", dY_dims, dY_data_half); std::vector log_prob_data_half(log_prob_data.size()); - ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), int(log_prob_data.size())); + ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), static_cast(log_prob_data.size())); test.AddInput("log_prob", log_prob_dims, log_prob_data_half); test.AddInput("index", index_dims, index_data); @@ -614,7 +614,7 @@ static void TestSoftmaxCrossEntropyLossInternalGrad(const std::vector& const std::vector& weight_dims, const std::vector& dX_dims, const std::string& reduction, const std::int64_t ignore_index = -1, const bool test_fp16 = false, - const double error_tolerance = 1e-4) { + const double error_tolerance = 1e-4, const bool has_bias = false) { CompareOpTester test("SoftmaxCrossEntropyLossInternalGrad", 1, onnxruntime::kMSDomain); test.AddAttribute("reduction", reduction); @@ -630,23 +630,30 @@ static void TestSoftmaxCrossEntropyLossInternalGrad(const std::vector& std::vector weight_data = random.Uniform(weight_dims, 0.0f, 1.0f); if (test_fp16) { std::vector dY_data_half(dY_data.size()); - ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), int(dY_data.size())); + ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), static_cast(dY_data.size())); test.AddInput("dY", dY_dims, dY_data_half); std::vector log_prob_data_half(log_prob_data.size()); - ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), int(log_prob_data.size())); + ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), static_cast(log_prob_data.size())); test.AddInput("log_prob", log_prob_dims, log_prob_data_half); test.AddInput("index", index_dims, index_data); std::vector weight_data_half(weight_data.size()); - ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), int(weight_data.size())); + ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), static_cast(weight_data.size())); test.AddInput("weight", weight_dims, weight_data_half); - if (ignore_index != -1) { + if (ignore_index != -1 || has_bias) { test.AddInput("ignore_index", {}, &ignore_index, 1); } + if (has_bias) { + std::vector bias_data = random.Uniform(dX_dims, 0.0f, 1.0f); + std::vector bias_data_half(bias_data.size()); + ConvertFloatToMLFloat16(bias_data.data(), bias_data_half.data(), static_cast(bias_data.size())); + test.AddInput("bias", dX_dims, bias_data_half); + } + std::vector dX_data = FillZeros(dX_dims); test.AddOutput("dX", dX_dims, dX_data); @@ -656,10 +663,15 @@ static void TestSoftmaxCrossEntropyLossInternalGrad(const std::vector& test.AddInput("log_prob", log_prob_dims, log_prob_data); test.AddInput("index", index_dims, index_data); test.AddInput("weight", weight_dims, weight_data); - if (ignore_index != -1) { + if (ignore_index != -1 || has_bias) { test.AddInput("ignore_index", {}, &ignore_index, 1); } + if (has_bias) { + std::vector bias_data = random.Uniform(dX_dims, 0.0f, 1.0f); + test.AddInput("bias", dX_dims, bias_data); + } + std::vector dX_data = FillZeros(dX_dims); test.AddOutput("dX", dX_dims, dX_data); @@ -681,6 +693,20 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLossInternalGrad_TinySizeTensor) { TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0); TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0); TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0); + + // Bias. + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, false, + 1e-4, true); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, false, + 1e-4, true); + TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", -1, false, 1e-4, + true); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, false, + 1e-4, true); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, false, + 1e-4, true); + TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0, false, 1e-4, + true); } TEST(CudaKernelTest, SoftmaxCrossEntropyLossInternalGrad_TinySizeTensor_half) { @@ -689,14 +715,32 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLossInternalGrad_TinySizeTensor_half) { std::vector index_dims{8}; std::vector weight_dims{2}; std::vector dX_dims{8, 2}; - TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, true, 5e-2); - TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, true, 5e-2); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, true, + 5e-2); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, true, + 5e-2); TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", -1, true, 5e-2); // Just test ignore_index for small tensor because it will increase test time a lot with little verification gain. - TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, true, 5e-2); - TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, true, 5e-2); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, true, + 5e-2); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, true, + 5e-2); TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0, true, 5e-2); + + // Bias. + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, true, + 5e-2, true); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, true, + 5e-2, true); + TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", -1, true, 5e-2, + true); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, true, + 5e-2, true); + TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, true, + 5e-2, true); + TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0, true, 5e-2, + true); } } // namespace test diff --git a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc index 73fddefc33..7b80ea651a 100644 --- a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc +++ b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc @@ -107,7 +107,7 @@ Status SoftmaxCrossEntropyLoss::Compute(OpKernelContext* context) const ORT_ENFORCE(p_ignore_index->Shape().IsScalar(), "ignore_index should be a scalar."); ignore_index = *(p_ignore_index->template Data()); } - + const TensorShape logit_shape{logit.Shape()}; const TensorShape label_shape{label.Shape()}; VerifyLogitWeightAndLabelShape(logit_shape, label_shape, p_weight ? &p_weight->Shape() : nullptr); @@ -386,6 +386,16 @@ Status SoftmaxCrossEntropyLossGrad::Compute(OpKernelContext* context) co d_logit->Reshape(new_shape); } + // Bias. + const Tensor* p_bias = context->Input(5); + if (p_bias) { + ORT_ENFORCE(probability_shape.Size() == p_bias->Shape().Size()); + const T1* bias_data = p_bias->Data(); + for (size_t i = 0; i < static_cast(probability_shape.Size()); ++i) { + d_logit_data[i] += bias_data[i]; + } + } + return Status::OK(); } @@ -403,4 +413,4 @@ REGISTER_KERNEL_INTERNAL_TYPED(SoftmaxCrossEntropyLossInternalGrad, SoftmaxCross REGISTER_KERNEL_INTERNAL_TYPED(SoftmaxCrossEntropyLossInternalGrad, SoftmaxCrossEntropyLossGrad, float, int64_t) } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc index c4ec52e48a..dcc3e2fef0 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc +++ b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cc @@ -197,10 +197,11 @@ Status SoftmaxCrossEntropyLossGrad::ComputeInternal(OpKernelContext* ctx const Tensor& label = *ctx->Input(2); const Tensor* p_weight = ctx->Input(3); const Tensor* p_ignore_index = ctx->Input(4); + const Tensor* p_bias = ctx->Input(5); int64_t ignore_index = ignore_index_; if (p_ignore_index) { ORT_ENFORCE(p_ignore_index->Shape().IsScalar(), "ignore_index should be a scalar."); - ignore_index = *(p_ignore_index->template Data()); + ignore_index = *(p_ignore_index->Data()); } const TensorShape probability_shape{log_prob.Shape()}; @@ -213,10 +214,10 @@ Status SoftmaxCrossEntropyLossGrad::ComputeInternal(OpKernelContext* ctx int64_t C; onnxruntime::contrib::GetNDCFromLogitAndLabelShape(probability_shape, label_shape, N_D, C); Tensor* d_logit = ctx->Output(0, probability_shape); - const T* dY_data = dY.template Data(); - const T* log_prob_data = log_prob.template Data(); - const Tin* label_data = label.template Data(); - T* d_logit_data = d_logit->template MutableData(); + const T* dY_data = dY.Data(); + const T* log_prob_data = log_prob.Data(); + const Tin* label_data = label.Data(); + T* d_logit_data = d_logit->MutableData(); const T* weight_data = nullptr; OrtValue transpose_output; TensorShapeVector new_shape; @@ -230,12 +231,12 @@ Status SoftmaxCrossEntropyLossGrad::ComputeInternal(OpKernelContext* ctx onnxruntime::contrib::GetPermutationAndShape(true, probability_shape, new_shape, permutations); transpose_output = AllocateTensorInMLValue(log_prob.DataType(), new_shape, alloc); ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(cuda::Transpose(info), permutations, log_prob, *transpose_output.GetMutable())); - log_prob_data = (*transpose_output.GetMutable()).template Data(); + log_prob_data = (*transpose_output.GetMutable()).Data(); } if (p_weight) { const Tensor& weight = *p_weight; - weight_data = weight.template Data(); + weight_data = weight.Data(); } IAllocatorUniquePtr weight_data_nd = GetScratchBuffer(N_D); @@ -267,12 +268,15 @@ Status SoftmaxCrossEntropyLossGrad::ComputeInternal(OpKernelContext* ctx CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(TBuf), cudaMemcpyHostToDevice, Stream())); } + const T* bias_data = p_bias ? p_bias->Data() : nullptr; + SoftmaxCrossEntropyLossGradImpl(Stream(), reinterpret_cast(dY_data), reinterpret_cast(log_prob_data), label_data, reinterpret_cast(weight_data_nd_data), normalize_factor_data.get(), + reinterpret_cast(bias_data), N_D, C, ReductionType::NONE == reduction_, diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu index f536054e7b..4e7171a072 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.cu @@ -92,15 +92,17 @@ void SoftmaxCrossEntropyLossImpl(cudaStream_t stream, const T* log_prob, const T LaunchElementwiseKernel(stream, output_data, op, count); } -template +template struct OpWeightedSoftmaxCrossEntropyLossGrad { OpWeightedSoftmaxCrossEntropyLossGrad(const T* dY_data, const T* log_prob_data, const Tin* label_data, - const T* weight_data, const TAcc* normalize_factor_data, Tin C) + const T* weight_data, const TAcc* normalize_factor_data, const T* bias_data, + Tin C) : dY_data_(dY_data), log_prob_data_(log_prob_data), label_data_(label_data), weight_data_(weight_data), normalize_factor_data_(normalize_factor_data), + bias_data_(bias_data), C_(C) { C_fdm_ = fast_divmod(static_cast(C)); } @@ -108,15 +110,16 @@ struct OpWeightedSoftmaxCrossEntropyLossGrad { __device__ __inline__ T operator()(CUDA_LONG idx) const { // normalize_factor is sum of labels' weights. Because zero sum implies all weights are 0, the loss function should // be constant 0 and its corresponding gradient should be 0 as well. + T result = T(0.f); if (*normalize_factor_data_ != TAcc(0.f)) { int row, d; C_fdm_.divmod(idx, row, d); CUDA_KERNEL_ASSERT(weight_data_[row] == T(0.f) || (label_data_[row] >= 0 && label_data_[row] < C_)); - return static_cast(static_cast((IsReductionNone ? dY_data_[row] : *dY_data_) * weight_data_[row]) * - (_Exp(static_cast(log_prob_data_[idx])) - (TAcc)(d == label_data_[row])) / - (*normalize_factor_data_)); + result = static_cast(static_cast((IsReductionNone ? dY_data_[row] : *dY_data_) * weight_data_[row]) * + (_Exp(static_cast(log_prob_data_[idx])) - (TAcc)(d == label_data_[row])) / + (*normalize_factor_data_)); } - return T(0.f); + return HasBias ? result + bias_data_[idx] : result; } const T* dY_data_; @@ -124,23 +127,33 @@ struct OpWeightedSoftmaxCrossEntropyLossGrad { const Tin* label_data_; const T* weight_data_; const TAcc* normalize_factor_data_; + const T* bias_data_; Tin C_; fast_divmod C_fdm_; }; template void SoftmaxCrossEntropyLossGradImpl(cudaStream_t stream, const T* dY, const T* log_prob, const Tin* label, - const T* weight, const TAcc* normalize_factor, size_t count, size_t label_depth, - bool reduction_none, T* output_data) { + const T* weight, const TAcc* normalize_factor, const T* bias_data, size_t count, + size_t label_depth, bool reduction_none, T* output_data) { +#define LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(is_reduction_none, has_bias) \ + OpWeightedSoftmaxCrossEntropyLossGrad op( \ + dY, log_prob, label, weight, normalize_factor, bias_data, static_cast(label_depth)); \ + LaunchElementwiseKernel(stream, output_data, op, count * label_depth) if (reduction_none) { - OpWeightedSoftmaxCrossEntropyLossGrad op(dY, log_prob, label, weight, normalize_factor, - static_cast(label_depth)); - LaunchElementwiseKernel(stream, output_data, op, count * label_depth); + if (bias_data) { + LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(true, true); + } else { + LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(true, false); + } } else { - OpWeightedSoftmaxCrossEntropyLossGrad op(dY, log_prob, label, weight, normalize_factor, - static_cast(label_depth)); - LaunchElementwiseKernel(stream, output_data, op, count * label_depth); + if (bias_data) { + LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(false, true); + } else { + LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(false, false); + } } +#undef LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL } #define INSTANTIATE_SCE_LOSS_IMPL(T, TAcc, Tin) \ @@ -148,8 +161,9 @@ void SoftmaxCrossEntropyLossGradImpl(cudaStream_t stream, const T* dY, const T* const TAcc* normalize_factor, size_t count, size_t label_depth, \ int64_t ignore_index, T* output_data); \ template void SoftmaxCrossEntropyLossGradImpl(cudaStream_t stream, const T* dY, const T* log_prob, const Tin* label, \ - const T* weight, const TAcc* normalize_factor, size_t count, \ - size_t label_depth, bool reducation_none, T* output_data) + const T* weight, const TAcc* normalize_factor, const T* bias_data, \ + size_t count, size_t label_depth, bool reducation_none, \ + T* output_data) INSTANTIATE_SCE_LOSS_IMPL(float, float, int32_t); INSTANTIATE_SCE_LOSS_IMPL(float, float, int64_t); diff --git a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h index a9ba04f077..85b353cd56 100644 --- a/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h +++ b/orttraining/orttraining/training_ops/cuda/loss/softmax_cross_entropy_loss_impl.h @@ -30,6 +30,7 @@ void SoftmaxCrossEntropyLossGradImpl( const Tin* label, const T* weight, const TAcc* normalize_factor, + const T* bias_data, size_t count, size_t label_depth, bool reduction_none,