SoftmaxCrossEntropyLossInternalGrad and Sum Fusion (#12746)

* fuse scegrad and sum

* add yield output shapes to value_info

* resolve comments

* fix merge main
This commit is contained in:
Vincent Wang 2022-09-14 14:45:51 +08:00 committed by GitHub
parent 568950e28c
commit da07c83948
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 519 additions and 43 deletions

View file

@ -65,6 +65,7 @@
#ifdef ENABLE_TRAINING
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
#include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h"
#endif
#endif // !defined(ORT_MINIMAL_BUILD)
@ -256,6 +257,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
#ifdef ENABLE_TRAINING
transformers.emplace_back(std::make_unique<BitmaskDropoutReplacement>(cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasSoftmaxDropoutFusion>(cuda_rocm_eps));
transformers.emplace_back(std::make_unique<SceLossGradBiasFusion>(cpu_cuda_rocm_eps));
#endif
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_rocm_eps));

View file

@ -96,6 +96,23 @@ class ModelTestBuilder {
return &graph_.GetOrCreateNodeArg(name, nullptr);
}
// Creates (or fetches) a typed NodeArg for an intermediate value, using a freshly generated "node" name.
// The element type is derived from T. When `shape` is provided, one dimension is added per entry; an entry of -1
// adds a dimension without a value (i.e. unknown). When `shape` is std::nullopt, no shape is attached at all.
template <typename T>
NodeArg* MakeIntermediate(const std::optional<std::vector<int64_t>>& shape) {
  ONNX_NAMESPACE::TypeProto type_proto;
  auto* tensor_type = type_proto.mutable_tensor_type();
  tensor_type->set_elem_type(utils::ToTensorProtoElementType<T>());

  if (shape.has_value()) {
    // Materialize the shape message even for an empty dimension list (scalar).
    auto* shape_proto = tensor_type->mutable_shape();
    for (const int64_t dim_value : *shape) {
      auto* dim = shape_proto->add_dim();
      if (dim_value == -1) continue;  // Leave the dimension unset.
      dim->set_dim_value(dim_value);
    }
  }

  const std::string arg_name = graph_.GenerateNodeArgName("node");
  return &graph_.GetOrCreateNodeArg(arg_name, &type_proto);
}
template <typename T>
NodeArg* MakeInitializer(const std::vector<int64_t>& shape, const std::vector<T>& data) {
std::string name = graph_.GenerateNodeArgName("constant");

View file

@ -448,6 +448,15 @@ void OrtModuleGraphBuilder::FindModuleOutputNeededForBackward() {
}
}
}
// Graph resolve will have the YieldOp outputs' shapes inferred. To avoid losing this information when
// transferring the model from backend to frontend (in case any graph optimization requires this shape information),
// add them to the graph's ValueInfo.
for (const auto& node_def : yield_node->OutputDefs()) {
if (node_def->TypeAsProto()) {
gradient_graph.AddValueInfo(node_def);
}
}
}
void OrtModuleGraphBuilder::UpdatePythonOpInputsRequireGradInfo(

View file

@ -333,11 +333,23 @@ bool SCELossGradFunBuilder(bool ignore_index_as_attr, const FunctionBodyBuildCon
)");
builder.Add(R"(
adj_BCD = CastLike (one_hot_label_BCD, prob_BCD)
grad_BCD = Sub (prob_BCD, adj_BCD)
d_logits_BCD = Mul (d_loss_B1D, grad_BCD)
d_logits = Reshape (d_logits_BCD, orig_shape)
adj_BCD = CastLike (one_hot_label_BCD, prob_BCD)
grad_BCD = Sub (prob_BCD, adj_BCD)
d_logits_BCD = Mul (d_loss_B1D, grad_BCD)
)");
if (ctx.hasInput(5)) {
builder.Add(R"(
d_logits_without_bias = Reshape (d_logits_BCD, orig_shape)
bias_shaped = Reshape (bias, orig_shape)
d_logits = Add(d_logits_without_bias, bias_shaped)
)");
} else {
builder.Add(R"(
d_logits = Reshape (d_logits_BCD, orig_shape)
)");
}
schema.BuildFunction(functionProto);
return true;
};
@ -3892,6 +3904,7 @@ Return true if all elements are true and false otherwise.
.Input(4, "ignore_index",
"Scalar tensor to specify a target value that is ignored and does not contribute to the input gradient.",
"I", OpSchema::Optional)
.Input(5, "bias", "data to be non-broadcasting added to the gradient.", "T", OpSchema::Optional)
.Output(0, "d_logits", "gradient of logits", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain to float, float16 and double tensors.")

View file

@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
namespace onnxruntime {
/**
Fuse SoftmaxCrossEntropyLossInternalGrad + Reshape(optional) + Sum/Add to SoftmaxCrossEntropyLossInternalGrad.
If it's Sum Op, it requires that it has only 2 inputs. Sum/Add must be non-broadcasting computation.

The other operand of the Sum/Add becomes the optional "bias" input (index 5) of the fused node. Missing optional
inputs in front of it are filled in: an empty NodeArg for "weight" and an int64 initializer holding -1 for
"ignore_index".

A candidate is skipped when the SoftmaxCrossEntropyLossInternalGrad node (or the optional Reshape) also produces a
graph output: the fusion removes/rewrites those intermediate values, so a graph output referring to them would either
lose its producer or silently start to include the bias.
*/
Status SceLossGradBiasFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                        const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  for (auto node_index : node_topology_list) {
    auto* node_ptr = graph.GetNode(node_index);
    if (!node_ptr) continue;  // Node was removed.

    auto& node = *node_ptr;
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    // Only an un-fused (no bias input yet) SCE grad node whose single output has exactly one consumer and is not
    // exposed as a graph output can be fused.
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLossInternalGrad", {1}, kMSDomain) ||
        !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.InputDefs().size() == 6 ||
        graph.NodeProducesGraphOutput(node) || node.GetOutputEdgesCount() != 1) {
      continue;
    }

    // sce_grad_def tracks the NodeArg that is expected to feed the Sum/Add node.
    NodeArg* sce_grad_def = node.MutableOutputDefs()[0];
    Node* p_next = graph.GetNode(node.OutputNodesBegin()->Index());
    Node* p_reshape = nullptr;
    // Optional Reshape between the SCE grad node and the Sum/Add. Same restrictions apply: single consumer and
    // not a graph output, because its output NodeArg is replaced by the Sum/Add output below.
    if (graph_utils::IsSupportedOptypeVersionAndDomain(*p_next, "Reshape", {5, 13, 14}) &&
        graph_utils::IsSupportedProvider(*p_next, GetCompatibleExecutionProviders()) &&
        !graph.NodeProducesGraphOutput(*p_next) && p_next->GetOutputEdgesCount() == 1) {
      p_reshape = p_next;
      sce_grad_def = p_reshape->MutableOutputDefs()[0];
      p_next = graph.GetNode(p_next->OutputNodesBegin()->Index());
    }

    if (!(graph_utils::IsSupportedOptypeVersionAndDomain(*p_next, "Add", {7, 13, 14}) ||
          graph_utils::IsSupportedOptypeVersionAndDomain(*p_next, "Sum", {6, 8, 13})) ||
        !graph_utils::IsSupportedProvider(*p_next, GetCompatibleExecutionProviders()) ||
        p_next->InputDefs().size() != 2) {
      continue;
    }

    Node& sum_node = *p_next;
    auto& sum_input_defs = sum_node.MutableInputDefs();

    // The Sum/Add must be non-broadcasting: both operands need known shapes that match dimension by dimension.
    auto shape0 = sum_input_defs[0]->Shape();
    auto shape1 = sum_input_defs[1]->Shape();
    if (!shape0 || !shape1 || shape0->dim_size() != shape1->dim_size()) {
      continue;
    }

    bool has_same_shape = true;
    for (int i = 0; i < shape0->dim_size(); ++i) {
      if (shape0->dim(i) != shape1->dim(i)) {
        has_same_shape = false;
        break;
      }
    }

    if (!has_same_shape) continue;

    // Whichever Sum/Add operand is not the gradient is the bias.
    NodeArg* bias_def = sce_grad_def == sum_input_defs[0] ? sum_input_defs[1] : sum_input_defs[0];
    auto& scegrad_inputs = node.MutableInputDefs();
    InlinedVector<NodeArg*> new_scegrad_node_inputs{scegrad_inputs[0], scegrad_inputs[1], scegrad_inputs[2]};
    InlinedVector<NodeArg*> new_scegrad_node_outputs;

    // Input 3 (weight): keep it if present, otherwise use an empty placeholder so bias lands at index 5.
    if (scegrad_inputs.size() >= 4) {
      new_scegrad_node_inputs.emplace_back(scegrad_inputs[3]);
    } else {
      new_scegrad_node_inputs.emplace_back(&graph.GetOrCreateNodeArg("", nullptr));
    }

    // Input 4 (ignore_index): keep it if present, otherwise add a scalar int64 initializer with value -1.
    if (scegrad_inputs.size() >= 5) {
      new_scegrad_node_inputs.emplace_back(scegrad_inputs[4]);
    } else {
      ONNX_NAMESPACE::TensorProto ignore_index_initializer_proto;
      ignore_index_initializer_proto.set_name(graph.GenerateNodeArgName("sce_grad_ignore_index"));
      ignore_index_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
      ignore_index_initializer_proto.add_int64_data(static_cast<int64_t>(-1));
      new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializer(graph, ignore_index_initializer_proto));
    }

    new_scegrad_node_inputs.emplace_back(bias_def);

    // Without a Reshape the fused node directly produces the Sum/Add output. With a Reshape it keeps feeding the
    // Reshape, which in turn takes over the Sum/Add output (see FinalizeNodeFusion below).
    if (!p_reshape) {
      new_scegrad_node_outputs.emplace_back(sum_node.MutableOutputDefs()[0]);
    } else {
      new_scegrad_node_outputs.emplace_back(p_reshape->MutableInputDefs()[0]);
    }

    Node& new_scegrad_node =
        graph.AddNode(graph.GenerateNodeName("FusedSoftmaxCrossEntropyLossInternalGrad"),
                      "SoftmaxCrossEntropyLossInternalGrad", "FusedSoftmaxCrossEntropyLossInternalGrad",
                      new_scegrad_node_inputs, new_scegrad_node_outputs, &node.GetAttributes(), kMSDomain);
    new_scegrad_node.SetExecutionProviderType(node.GetExecutionProviderType());

    graph_utils::RemoveNodeOutputEdges(graph, node);
    graph.RemoveNode(node.Index());
    if (p_reshape) {
      graph_utils::FinalizeNodeFusion(graph, *p_reshape, sum_node);
    } else {
      graph_utils::RemoveNodeOutputEdges(graph, sum_node);
      graph.RemoveNode(sum_node.Index());
    }

    modified = true;
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
@Class SceLossGradBiasFusion

Graph transformer that collapses the pattern
  SoftmaxCrossEntropyLossInternalGrad -> Reshape (optional) -> Sum/Add
into a single SoftmaxCrossEntropyLossInternalGrad node.
A Sum node is only eligible when it has exactly 2 inputs, and the Sum/Add must not broadcast.
*/
class SceLossGradBiasFusion : public GraphTransformer {
 public:
  // compatible_execution_providers is forwarded unchanged to the GraphTransformer base class.
  explicit SceLossGradBiasFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
      : GraphTransformer("SceLossGradBiasFusion", compatible_execution_providers) {}

  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -26,6 +26,7 @@
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/optimizer/loss_rewriter.h"
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
#include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h"
#include <random>
@ -281,6 +282,229 @@ TEST_F(GraphTransformationTests, BiasSoftmaxDropoutFusion) {
RunBiasSoftmaxDropoutFusionTest<MLFloat16>(true, true, 14, *logger_);
}
// Builds SoftmaxCrossEntropyLossInternalGrad -> [Reshape] -> Add/Sum, applies SceLossGradBiasFusion and checks that
// the Add/Sum is gone and the remaining SoftmaxCrossEntropyLossInternalGrad node has 6 inputs and keeps its
// "reduction" attribute.
//   has_reshape        insert a Reshape between the grad node and the bias op.
//   is_add_op          use Add (true) or Sum (false) as the bias op.
//   is_bias_lhs_input  put the bias on the first (true) or second (false) input of the bias op.
//   has_weight         feed a weight input to the grad node.
//   has_ignore_index   feed an ignore_index input to the grad node (also forces a weight input).
template <typename T>
void RunSceLossGradBiasFusionTest(bool has_reshape, bool is_add_op, bool is_bias_lhs_input, bool has_weight,
                                  bool has_ignore_index, const std::string& reduction, int opset_version,
                                  const logging::Logger& logger) {
  const std::string bias_op_type = is_add_op ? "Add" : "Sum";

  auto build_test_case = [&](ModelTestBuilder& builder) {
    // Required grad-node inputs: dY, log_prob, index.
    std::vector<NodeArg*> scegrad_inputs;
    scegrad_inputs.emplace_back(builder.MakeInput<T>({{}}));
    scegrad_inputs.emplace_back(builder.MakeInput<T>({{8, 2}}));
    scegrad_inputs.emplace_back(builder.MakeInput<int64_t>({{8}}));
    // ignore_index sits after weight, so a weight input is created whenever either one is requested.
    if (has_weight || has_ignore_index) {
      scegrad_inputs.emplace_back(builder.MakeInput<T>({{2}}));
    }
    if (has_ignore_index) {
      scegrad_inputs.emplace_back(builder.MakeInput<int64_t>({{}}));
    }

    auto* sce_grad_out = builder.MakeIntermediate();

    std::vector<NodeArg*> reshape_inputs;
    std::vector<NodeArg*> reshape_outputs;
    NodeArg* grad_arg = sce_grad_out;  // The value consumed by the bias op.
    NodeArg* bias_arg = nullptr;
    if (has_reshape) {
      reshape_inputs.emplace_back(sce_grad_out);
      reshape_inputs.emplace_back(builder.MakeInput<int64_t>({{1}}));
      grad_arg = builder.MakeIntermediate<T>({{16}});
      reshape_outputs.emplace_back(grad_arg);
      bias_arg = builder.MakeInput<T>({{16}});
    } else {
      bias_arg = builder.MakeInput<T>({{8, 2}});
    }

    std::vector<NodeArg*> bias_op_inputs;
    bias_op_inputs.emplace_back(is_bias_lhs_input ? bias_arg : grad_arg);
    bias_op_inputs.emplace_back(is_bias_lhs_input ? grad_arg : bias_arg);

    auto* dx_out = builder.MakeOutput();
    builder.AddNode("SoftmaxCrossEntropyLossInternalGrad", scegrad_inputs, {sce_grad_out}, kMSDomain)
        .AddAttribute("reduction", reduction);
    if (has_reshape) {
      builder.AddNode("Reshape", reshape_inputs, reshape_outputs);
    }
    builder.AddNode(bias_op_type, bias_op_inputs, {dx_out});
  };

  // Before the transform: one grad node and one bias op.
  auto pre_graph_checker = [&](Graph& graph) {
    auto op_counts = CountOpsInGraph(graph);
    ASSERT_EQ(op_counts["com.microsoft.SoftmaxCrossEntropyLossInternalGrad"], 1);
    ASSERT_EQ(op_counts[bias_op_type], 1);
  };

  // After the transform: the bias op is folded into the grad node.
  auto post_graph_checker = [&](Graph& graph) {
    auto op_counts = CountOpsInGraph(graph);
    ASSERT_EQ(op_counts["com.microsoft.SoftmaxCrossEntropyLossInternalGrad"], 1);
    ASSERT_EQ(op_counts[bias_op_type], 0);
    for (auto& node : graph.Nodes()) {
      if (node.OpType() != "SoftmaxCrossEntropyLossInternalGrad") continue;
      auto& attrs = node.GetAttributes();
      ASSERT_TRUE(attrs.find("reduction") != attrs.end());
      ASSERT_EQ(reduction, attrs.at("reduction").s());
      ASSERT_EQ(6, static_cast<int>(node.InputDefs().size()));
    }
  };

  std::unique_ptr<GraphTransformer> fusion_transformer = std::make_unique<SceLossGradBiasFusion>();
  TestGraphTransformer(build_test_case, opset_version, logger, std::move(fusion_transformer),
                       TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker);
}
// Runs RunSceLossGradBiasFusionTest over a fixed list of configurations for the given opset version.
// The list alternates between float and MLFloat16 and covers with/without Reshape, Add vs. Sum, bias on either side,
// optional weight / ignore_index inputs and all three reduction modes.
void RunSceLossGradBiasFusionTestWrapper(int opset_version, const logging::Logger& logger) {
  struct Config {
    bool use_fp16;  // true -> MLFloat16, false -> float
    bool has_reshape;
    bool is_add_op;
    bool is_bias_lhs_input;
    bool has_weight;
    bool has_ignore_index;
    const char* reduction;
  };

  const Config configs[] = {
      // Without Reshape, Sum.
      {false, false, false, false, true, true, "none"},
      {true, false, false, true, true, false, "mean"},
      {false, false, false, false, false, false, "sum"},
      // Without Reshape, Add.
      {true, false, true, true, true, true, "none"},
      {false, false, true, false, true, false, "mean"},
      {true, false, true, true, false, false, "sum"},
      // With Reshape, Sum.
      {false, true, false, false, true, true, "none"},
      {true, true, false, true, true, false, "mean"},
      {false, true, false, false, false, false, "sum"},
      // With Reshape, Add.
      {true, true, true, true, true, true, "none"},
      {false, true, true, false, true, false, "mean"},
      {true, true, true, true, false, false, "sum"},
  };

  for (const Config& cfg : configs) {
    if (cfg.use_fp16) {
      RunSceLossGradBiasFusionTest<MLFloat16>(cfg.has_reshape, cfg.is_add_op, cfg.is_bias_lhs_input, cfg.has_weight,
                                              cfg.has_ignore_index, cfg.reduction, opset_version, logger);
    } else {
      RunSceLossGradBiasFusionTest<float>(cfg.has_reshape, cfg.is_add_op, cfg.is_bias_lhs_input, cfg.has_weight,
                                          cfg.has_ignore_index, cfg.reduction, opset_version, logger);
    }
  }
}
// Exercises the fusion for every covered opset version (12, 13 and 14).
TEST_F(GraphTransformationTests, SceLossGradBiasFusion) {
  for (const int opset_version : {12, 13, 14}) {
    RunSceLossGradBiasFusionTestWrapper(opset_version, *logger_);
  }
}
// Graphs that look similar to the fusable pattern but must be left untouched by SceLossGradBiasFusion.
TEST_F(GraphTransformationTests, SceLossGradBiasFusion_Invalid) {
  // Used both before and after the transform: the grad node and the Sum node must both still be there.
  const auto expect_not_fused = [](Graph& graph) {
    auto op_counts = CountOpsInGraph(graph);
    ASSERT_EQ(op_counts["com.microsoft.SoftmaxCrossEntropyLossInternalGrad"], 1);
    ASSERT_EQ(op_counts["Sum"], 1);
  };

  // Applies a fresh SceLossGradBiasFusion (opset 14, Level2, one step) to the graph built by build_test_case.
  const auto run_case = [&](const auto& build_test_case) {
    std::unique_ptr<GraphTransformer> fusion_transformer = std::make_unique<SceLossGradBiasFusion>();
    TestGraphTransformer(build_test_case, 14, *logger_, std::move(fusion_transformer), TransformerLevel::Level2, 1,
                         expect_not_fused, expect_not_fused);
  };

  // Sum has more than 2 inputs.
  run_case([&](ModelTestBuilder& builder) {
    auto* dY_arg = builder.MakeInput<float>({{}});
    auto* log_prob_arg = builder.MakeInput<float>({{8, 2}});
    auto* index_arg = builder.MakeInput<int64_t>({{8}});
    auto* sce_grad_out = builder.MakeIntermediate();
    auto* bias1_arg = builder.MakeInput<float>({{8, 2}});
    auto* bias2_arg = builder.MakeInput<float>({{8, 2}});
    auto* dx_out = builder.MakeOutput();
    builder
        .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain)
        .AddAttribute("reduction", "sum");
    builder.AddNode("Sum", {sce_grad_out, bias1_arg, bias2_arg}, {dx_out});
  });

  // SceGrad has more than 1 consumers.
  run_case([&](ModelTestBuilder& builder) {
    auto* dY_arg = builder.MakeInput<float>({{}});
    auto* log_prob_arg = builder.MakeInput<float>({{8, 2}});
    auto* index_arg = builder.MakeInput<int64_t>({{8}});
    auto* sce_grad_out = builder.MakeIntermediate();
    auto* bias_arg = builder.MakeInput<float>({{8, 2}});
    auto* dx_out = builder.MakeOutput();
    auto* identity_out = builder.MakeOutput();
    builder
        .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain)
        .AddAttribute("reduction", "sum");
    builder.AddNode("Sum", {sce_grad_out, bias_arg}, {dx_out});
    builder.AddNode("Identity", {sce_grad_out}, {identity_out});
  });

  // Sum inputs shape mismatch (bias has a different rank than the gradient).
  run_case([&](ModelTestBuilder& builder) {
    auto* dY_arg = builder.MakeInput<float>({{}});
    auto* log_prob_arg = builder.MakeInput<float>({{8, 2}});
    auto* index_arg = builder.MakeInput<int64_t>({{8}});
    auto* sce_grad_out = builder.MakeIntermediate();
    auto* bias_arg = builder.MakeInput<float>({{2}});
    auto* dx_out = builder.MakeOutput();
    builder
        .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain)
        .AddAttribute("reduction", "sum");
    builder.AddNode("Sum", {sce_grad_out, bias_arg}, {dx_out});
  });

  // Sum inputs shape mismatch (Reshape output is {1, 8} while the bias is {8, 1}).
  run_case([&](ModelTestBuilder& builder) {
    auto* dY_arg = builder.MakeInput<float>({{}});
    auto* log_prob_arg = builder.MakeInput<float>({{8, 1}});
    auto* index_arg = builder.MakeInput<int64_t>({{8}});
    auto* bias_arg = builder.MakeInput<float>({{8, 1}});
    auto* sce_grad_out = builder.MakeIntermediate();
    auto* shape_arg = builder.MakeInput<int64_t>({{2}});
    auto* reshape_out = builder.MakeIntermediate<float>({{1, 8}});
    auto* dx_out = builder.MakeOutput();
    builder
        .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain)
        .AddAttribute("reduction", "sum");
    builder.AddNode("Reshape", {sce_grad_out, shape_arg}, {reshape_out});
    builder.AddNode("Sum", {reshape_out, bias_arg}, {dx_out});
  });

  // Reshape output has more than 1 consumers.
  run_case([&](ModelTestBuilder& builder) {
    auto* dY_arg = builder.MakeInput<float>({{}});
    auto* log_prob_arg = builder.MakeInput<float>({{8, 2}});
    auto* index_arg = builder.MakeInput<int64_t>({{8}});
    auto* bias_arg = builder.MakeInput<float>({{16}});
    auto* sce_grad_out = builder.MakeIntermediate();
    auto* shape_arg = builder.MakeInput<int64_t>({{1}});
    auto* reshape_out = builder.MakeIntermediate<float>({{16}});
    auto* dx_out = builder.MakeOutput();
    auto* identity_out = builder.MakeOutput();
    builder
        .AddNode("SoftmaxCrossEntropyLossInternalGrad", {dY_arg, log_prob_arg, index_arg}, {sce_grad_out}, kMSDomain)
        .AddAttribute("reduction", "sum");
    builder.AddNode("Reshape", {sce_grad_out, shape_arg}, {reshape_out});
    builder.AddNode("Sum", {reshape_out, bias_arg}, {dx_out});
    builder.AddNode("Identity", {reshape_out}, {identity_out});
  });
}
Node* GetNodeByName(Graph& graph, std::string node_name) {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

View file

@ -299,7 +299,7 @@ static void TestSoftmaxCrossEntropyLoss(CompareOpTester& test, const std::vector
}
if (test_fp16) {
std::vector<MLFloat16> X_data_half(X_data.size());
ConvertFloatToMLFloat16(X_data.data(), X_data_half.data(), int(X_data.size()));
ConvertFloatToMLFloat16(X_data.data(), X_data_half.data(), static_cast<int>(X_data.size()));
test.AddInput<MLFloat16>("X", *X_dims, X_data_half);
} else {
test.AddInput<float>("X", *X_dims, X_data);
@ -311,7 +311,7 @@ static void TestSoftmaxCrossEntropyLoss(CompareOpTester& test, const std::vector
std::vector<float> weight_data = random.Uniform<float>(*weight_dims, 0.0f, 1.0f);
if (test_fp16) {
std::vector<MLFloat16> weight_data_half(weight_data.size());
ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), int(weight_data.size()));
ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), static_cast<int>(weight_data.size()));
test.AddInput<MLFloat16>("weight", *weight_dims, weight_data_half);
} else {
test.AddInput<float>("weight", *weight_dims, weight_data);
@ -513,11 +513,11 @@ static void TestSoftmaxCrossEntropyLossGrad(const std::vector<int64_t>& dY_dims,
}
if (test_fp16) {
std::vector<MLFloat16> dY_data_half(dY_data.size());
ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), int(dY_data.size()));
ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), static_cast<int>(dY_data.size()));
test.AddInput<MLFloat16>("dY", dY_dims, dY_data_half);
std::vector<MLFloat16> log_prob_data_half(log_prob_data.size());
ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), int(log_prob_data.size()));
ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), static_cast<int>(log_prob_data.size()));
test.AddInput<MLFloat16>("log_prob", log_prob_dims, log_prob_data_half);
test.AddInput<int64_t>("index", index_dims, index_data);
@ -614,7 +614,7 @@ static void TestSoftmaxCrossEntropyLossInternalGrad(const std::vector<int64_t>&
const std::vector<int64_t>& weight_dims,
const std::vector<int64_t>& dX_dims, const std::string& reduction,
const std::int64_t ignore_index = -1, const bool test_fp16 = false,
const double error_tolerance = 1e-4) {
const double error_tolerance = 1e-4, const bool has_bias = false) {
CompareOpTester test("SoftmaxCrossEntropyLossInternalGrad", 1, onnxruntime::kMSDomain);
test.AddAttribute("reduction", reduction);
@ -630,23 +630,30 @@ static void TestSoftmaxCrossEntropyLossInternalGrad(const std::vector<int64_t>&
std::vector<float> weight_data = random.Uniform<float>(weight_dims, 0.0f, 1.0f);
if (test_fp16) {
std::vector<MLFloat16> dY_data_half(dY_data.size());
ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), int(dY_data.size()));
ConvertFloatToMLFloat16(dY_data.data(), dY_data_half.data(), static_cast<int>(dY_data.size()));
test.AddInput<MLFloat16>("dY", dY_dims, dY_data_half);
std::vector<MLFloat16> log_prob_data_half(log_prob_data.size());
ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), int(log_prob_data.size()));
ConvertFloatToMLFloat16(log_prob_data.data(), log_prob_data_half.data(), static_cast<int>(log_prob_data.size()));
test.AddInput<MLFloat16>("log_prob", log_prob_dims, log_prob_data_half);
test.AddInput<int64_t>("index", index_dims, index_data);
std::vector<MLFloat16> weight_data_half(weight_data.size());
ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), int(weight_data.size()));
ConvertFloatToMLFloat16(weight_data.data(), weight_data_half.data(), static_cast<int>(weight_data.size()));
test.AddInput<MLFloat16>("weight", weight_dims, weight_data_half);
if (ignore_index != -1) {
if (ignore_index != -1 || has_bias) {
test.AddInput<int64_t>("ignore_index", {}, &ignore_index, 1);
}
if (has_bias) {
std::vector<float> bias_data = random.Uniform<float>(dX_dims, 0.0f, 1.0f);
std::vector<MLFloat16> bias_data_half(bias_data.size());
ConvertFloatToMLFloat16(bias_data.data(), bias_data_half.data(), static_cast<int>(bias_data.size()));
test.AddInput<MLFloat16>("bias", dX_dims, bias_data_half);
}
std::vector<MLFloat16> dX_data = FillZeros<MLFloat16>(dX_dims);
test.AddOutput<MLFloat16>("dX", dX_dims, dX_data);
@ -656,10 +663,15 @@ static void TestSoftmaxCrossEntropyLossInternalGrad(const std::vector<int64_t>&
test.AddInput<float>("log_prob", log_prob_dims, log_prob_data);
test.AddInput<int64_t>("index", index_dims, index_data);
test.AddInput<float>("weight", weight_dims, weight_data);
if (ignore_index != -1) {
if (ignore_index != -1 || has_bias) {
test.AddInput<int64_t>("ignore_index", {}, &ignore_index, 1);
}
if (has_bias) {
std::vector<float> bias_data = random.Uniform<float>(dX_dims, 0.0f, 1.0f);
test.AddInput<float>("bias", dX_dims, bias_data);
}
std::vector<float> dX_data = FillZeros<float>(dX_dims);
test.AddOutput<float>("dX", dX_dims, dX_data);
@ -681,6 +693,20 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLossInternalGrad_TinySizeTensor) {
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0);
TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0);
// Bias.
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, false,
1e-4, true);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, false,
1e-4, true);
TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", -1, false, 1e-4,
true);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, false,
1e-4, true);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, false,
1e-4, true);
TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0, false, 1e-4,
true);
}
TEST(CudaKernelTest, SoftmaxCrossEntropyLossInternalGrad_TinySizeTensor_half) {
@ -689,14 +715,32 @@ TEST(CudaKernelTest, SoftmaxCrossEntropyLossInternalGrad_TinySizeTensor_half) {
std::vector<int64_t> index_dims{8};
std::vector<int64_t> weight_dims{2};
std::vector<int64_t> dX_dims{8, 2};
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, true, 5e-2);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, true, 5e-2);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, true,
5e-2);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, true,
5e-2);
TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", -1, true, 5e-2);
// Just test ignore_index for small tensor because it will increase test time a lot with little verification gain.
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, true, 5e-2);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, true, 5e-2);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, true,
5e-2);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, true,
5e-2);
TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0, true, 5e-2);
// Bias.
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", -1, true,
5e-2, true);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", -1, true,
5e-2, true);
TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", -1, true, 5e-2,
true);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "mean", 0, true,
5e-2, true);
TestSoftmaxCrossEntropyLossInternalGrad(dY_dims, log_prob_dims, index_dims, weight_dims, dX_dims, "sum", 0, true,
5e-2, true);
TestSoftmaxCrossEntropyLossInternalGrad({8}, log_prob_dims, index_dims, weight_dims, dX_dims, "none", 0, true, 5e-2,
true);
}
} // namespace test

View file

@ -107,7 +107,7 @@ Status SoftmaxCrossEntropyLoss<T1, T2>::Compute(OpKernelContext* context) const
ORT_ENFORCE(p_ignore_index->Shape().IsScalar(), "ignore_index should be a scalar.");
ignore_index = *(p_ignore_index->template Data<int64_t>());
}
const TensorShape logit_shape{logit.Shape()};
const TensorShape label_shape{label.Shape()};
VerifyLogitWeightAndLabelShape(logit_shape, label_shape, p_weight ? &p_weight->Shape() : nullptr);
@ -386,6 +386,16 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* context) co
d_logit->Reshape(new_shape);
}
// Bias.
const Tensor* p_bias = context->Input<Tensor>(5);
if (p_bias) {
ORT_ENFORCE(probability_shape.Size() == p_bias->Shape().Size());
const T1* bias_data = p_bias->Data<T1>();
for (size_t i = 0; i < static_cast<size_t>(probability_shape.Size()); ++i) {
d_logit_data[i] += bias_data[i];
}
}
return Status::OK();
}
@ -403,4 +413,4 @@ REGISTER_KERNEL_INTERNAL_TYPED(SoftmaxCrossEntropyLossInternalGrad, SoftmaxCross
REGISTER_KERNEL_INTERNAL_TYPED(SoftmaxCrossEntropyLossInternalGrad, SoftmaxCrossEntropyLossGrad, float, int64_t)
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -197,10 +197,11 @@ Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx
const Tensor& label = *ctx->Input<Tensor>(2);
const Tensor* p_weight = ctx->Input<Tensor>(3);
const Tensor* p_ignore_index = ctx->Input<Tensor>(4);
const Tensor* p_bias = ctx->Input<Tensor>(5);
int64_t ignore_index = ignore_index_;
if (p_ignore_index) {
ORT_ENFORCE(p_ignore_index->Shape().IsScalar(), "ignore_index should be a scalar.");
ignore_index = *(p_ignore_index->template Data<int64_t>());
ignore_index = *(p_ignore_index->Data<int64_t>());
}
const TensorShape probability_shape{log_prob.Shape()};
@ -213,10 +214,10 @@ Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx
int64_t C;
onnxruntime::contrib::GetNDCFromLogitAndLabelShape(probability_shape, label_shape, N_D, C);
Tensor* d_logit = ctx->Output(0, probability_shape);
const T* dY_data = dY.template Data<T>();
const T* log_prob_data = log_prob.template Data<T>();
const Tin* label_data = label.template Data<Tin>();
T* d_logit_data = d_logit->template MutableData<T>();
const T* dY_data = dY.Data<T>();
const T* log_prob_data = log_prob.Data<T>();
const Tin* label_data = label.Data<Tin>();
T* d_logit_data = d_logit->MutableData<T>();
const T* weight_data = nullptr;
OrtValue transpose_output;
TensorShapeVector new_shape;
@ -230,12 +231,12 @@ Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx
onnxruntime::contrib::GetPermutationAndShape(true, probability_shape, new_shape, permutations);
transpose_output = AllocateTensorInMLValue(log_prob.DataType(), new_shape, alloc);
ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(cuda::Transpose(info), permutations, log_prob, *transpose_output.GetMutable<Tensor>()));
log_prob_data = (*transpose_output.GetMutable<Tensor>()).template Data<T>();
log_prob_data = (*transpose_output.GetMutable<Tensor>()).Data<T>();
}
if (p_weight) {
const Tensor& weight = *p_weight;
weight_data = weight.template Data<T>();
weight_data = weight.Data<T>();
}
IAllocatorUniquePtr<T> weight_data_nd = GetScratchBuffer<T>(N_D);
@ -267,12 +268,15 @@ Status SoftmaxCrossEntropyLossGrad<T, Tin>::ComputeInternal(OpKernelContext* ctx
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(normalize_factor_data.get(), &normalize_factor, sizeof(TBuf), cudaMemcpyHostToDevice, Stream()));
}
const T* bias_data = p_bias ? p_bias->Data<T>() : nullptr;
SoftmaxCrossEntropyLossGradImpl(Stream(),
reinterpret_cast<const CudaT*>(dY_data),
reinterpret_cast<const CudaT*>(log_prob_data),
label_data,
reinterpret_cast<const CudaT*>(weight_data_nd_data),
normalize_factor_data.get(),
reinterpret_cast<const CudaT*>(bias_data),
N_D,
C,
ReductionType::NONE == reduction_,

View file

@ -92,15 +92,17 @@ void SoftmaxCrossEntropyLossImpl(cudaStream_t stream, const T* log_prob, const T
LaunchElementwiseKernel<T, decltype(op)>(stream, output_data, op, count);
}
template <typename T, typename TAcc, typename Tin, bool IsReductionNone>
template <typename T, typename TAcc, typename Tin, bool IsReductionNone, bool HasBias>
struct OpWeightedSoftmaxCrossEntropyLossGrad {
OpWeightedSoftmaxCrossEntropyLossGrad(const T* dY_data, const T* log_prob_data, const Tin* label_data,
const T* weight_data, const TAcc* normalize_factor_data, Tin C)
const T* weight_data, const TAcc* normalize_factor_data, const T* bias_data,
Tin C)
: dY_data_(dY_data),
log_prob_data_(log_prob_data),
label_data_(label_data),
weight_data_(weight_data),
normalize_factor_data_(normalize_factor_data),
bias_data_(bias_data),
C_(C) {
C_fdm_ = fast_divmod(static_cast<int>(C));
}
@ -108,15 +110,16 @@ struct OpWeightedSoftmaxCrossEntropyLossGrad {
__device__ __inline__ T operator()(CUDA_LONG idx) const {
// normalize_factor is sum of labels' weights. Because zero sum implies all weights are 0, the loss function should
// be constant 0 and its corresponding gradient should be 0 as well.
T result = T(0.f);
if (*normalize_factor_data_ != TAcc(0.f)) {
int row, d;
C_fdm_.divmod(idx, row, d);
CUDA_KERNEL_ASSERT(weight_data_[row] == T(0.f) || (label_data_[row] >= 0 && label_data_[row] < C_));
return static_cast<T>(static_cast<TAcc>((IsReductionNone ? dY_data_[row] : *dY_data_) * weight_data_[row]) *
(_Exp(static_cast<TAcc>(log_prob_data_[idx])) - (TAcc)(d == label_data_[row])) /
(*normalize_factor_data_));
result = static_cast<T>(static_cast<TAcc>((IsReductionNone ? dY_data_[row] : *dY_data_) * weight_data_[row]) *
(_Exp(static_cast<TAcc>(log_prob_data_[idx])) - (TAcc)(d == label_data_[row])) /
(*normalize_factor_data_));
}
return T(0.f);
return HasBias ? result + bias_data_[idx] : result;
}
const T* dY_data_;
@ -124,23 +127,33 @@ struct OpWeightedSoftmaxCrossEntropyLossGrad {
const Tin* label_data_;
const T* weight_data_;
const TAcc* normalize_factor_data_;
const T* bias_data_;
Tin C_;
fast_divmod C_fdm_;
};
template <typename T, typename TAcc, typename Tin>
void SoftmaxCrossEntropyLossGradImpl(cudaStream_t stream, const T* dY, const T* log_prob, const Tin* label,
const T* weight, const TAcc* normalize_factor, size_t count, size_t label_depth,
bool reduction_none, T* output_data) {
const T* weight, const TAcc* normalize_factor, const T* bias_data, size_t count,
size_t label_depth, bool reduction_none, T* output_data) {
#define LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(is_reduction_none, has_bias) \
OpWeightedSoftmaxCrossEntropyLossGrad<T, TAcc, Tin, is_reduction_none, has_bias> op( \
dY, log_prob, label, weight, normalize_factor, bias_data, static_cast<Tin>(label_depth)); \
LaunchElementwiseKernel<T, decltype(op)>(stream, output_data, op, count * label_depth)
if (reduction_none) {
OpWeightedSoftmaxCrossEntropyLossGrad<T, TAcc, Tin, true> op(dY, log_prob, label, weight, normalize_factor,
static_cast<Tin>(label_depth));
LaunchElementwiseKernel<T, decltype(op)>(stream, output_data, op, count * label_depth);
if (bias_data) {
LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(true, true);
} else {
LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(true, false);
}
} else {
OpWeightedSoftmaxCrossEntropyLossGrad<T, TAcc, Tin, false> op(dY, log_prob, label, weight, normalize_factor,
static_cast<Tin>(label_depth));
LaunchElementwiseKernel<T, decltype(op)>(stream, output_data, op, count * label_depth);
if (bias_data) {
LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(false, true);
} else {
LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL(false, false);
}
}
#undef LAUNCH_WEIGHTED_SOFTMAX_CROSS_ENTROPY_LOSS_GRAD_KERNEL
}
#define INSTANTIATE_SCE_LOSS_IMPL(T, TAcc, Tin) \
@ -148,8 +161,9 @@ void SoftmaxCrossEntropyLossGradImpl(cudaStream_t stream, const T* dY, const T*
const TAcc* normalize_factor, size_t count, size_t label_depth, \
int64_t ignore_index, T* output_data); \
template void SoftmaxCrossEntropyLossGradImpl(cudaStream_t stream, const T* dY, const T* log_prob, const Tin* label, \
const T* weight, const TAcc* normalize_factor, size_t count, \
size_t label_depth, bool reducation_none, T* output_data)
const T* weight, const TAcc* normalize_factor, const T* bias_data, \
size_t count, size_t label_depth, bool reducation_none, \
T* output_data)
INSTANTIATE_SCE_LOSS_IMPL(float, float, int32_t);
INSTANTIATE_SCE_LOSS_IMPL(float, float, int64_t);

View file

@ -30,6 +30,7 @@ void SoftmaxCrossEntropyLossGradImpl(
const Tin* label,
const T* weight,
const TAcc* normalize_factor,
const T* bias_data,
size_t count,
size_t label_depth,
bool reduction_none,