diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index 90bfa6a32b..e79db74a92 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -33,7 +33,9 @@ namespace training { // continue to use 32-bit precision. Others will used reduced precision. static const std::unordered_set FP32_Nodes = { "SparseSoftmaxCrossEntropy", - "SparseSoftmaxCrossEntropyGrad"}; + "SparseSoftmaxCrossEntropyGrad", + "SoftmaxCrossEntropyLoss", + "SoftmaxCrossEntropyLossGrad"}; bool IsFP32Node(const Node* node) { return FP32_Nodes.find(node->OpType()) != FP32_Nodes.cend(); @@ -54,6 +56,8 @@ static const std::unordered_map> stage2_fp32_node_ {"DropoutGrad", {2}}, {"SparseSoftmaxCrossEntropy", {0, 2}}, {"SparseSoftmaxCrossEntropyGrad", {0, 1, 3}}, + {"SoftmaxCrossEntropyLoss", {0, 2}}, + {"SoftmaxCrossEntropyLossGrad", {0, 1, 3}}, }; bool IsFP32(const std::unordered_map>& map, std::string opname, int argnum) {