From b9a5ed1fe2b2fc58aab706a1967a6d659de70879 Mon Sep 17 00:00:00 2001 From: "M. Zeeshan Siddiqui" Date: Thu, 30 Apr 2020 02:48:21 -0700 Subject: [PATCH] Add SoftmaxCrossEntropyLoss to mixed-precision-transformer. (#3760) --- .../orttraining/core/graph/mixed_precision_transformer.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index 90bfa6a32b..e79db74a92 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -33,7 +33,9 @@ namespace training { // continue to use 32-bit precision. Others will used reduced precision. static const std::unordered_set FP32_Nodes = { "SparseSoftmaxCrossEntropy", - "SparseSoftmaxCrossEntropyGrad"}; + "SparseSoftmaxCrossEntropyGrad", + "SoftmaxCrossEntropyLoss", + "SoftmaxCrossEntropyLossGrad"}; bool IsFP32Node(const Node* node) { return FP32_Nodes.find(node->OpType()) != FP32_Nodes.cend(); @@ -54,6 +56,8 @@ static const std::unordered_map> stage2_fp32_node_ {"DropoutGrad", {2}}, {"SparseSoftmaxCrossEntropy", {0, 2}}, {"SparseSoftmaxCrossEntropyGrad", {0, 1, 3}}, + {"SoftmaxCrossEntropyLoss", {0, 2}}, + {"SoftmaxCrossEntropyLossGrad", {0, 1, 3}}, }; bool IsFP32(const std::unordered_map>& map, std::string opname, int argnum) {