mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Add SoftmaxCrossEntropyLoss to mixed-precision-transformer. (#3760)
This commit is contained in:
parent
9f72752397
commit
b9a5ed1fe2
1 changed files with 5 additions and 1 deletions
|
|
@ -33,7 +33,9 @@ namespace training {
|
|||
// continue to use 32-bit precision. Others will used reduced precision.
|
||||
static const std::unordered_set<std::string> FP32_Nodes = {
|
||||
"SparseSoftmaxCrossEntropy",
|
||||
"SparseSoftmaxCrossEntropyGrad"};
|
||||
"SparseSoftmaxCrossEntropyGrad",
|
||||
"SoftmaxCrossEntropyLoss",
|
||||
"SoftmaxCrossEntropyLossGrad"};
|
||||
|
||||
bool IsFP32Node(const Node* node) {
|
||||
return FP32_Nodes.find(node->OpType()) != FP32_Nodes.cend();
|
||||
|
|
@ -54,6 +56,8 @@ static const std::unordered_map<std::string, std::vector<int>> stage2_fp32_node_
|
|||
{"DropoutGrad", {2}},
|
||||
{"SparseSoftmaxCrossEntropy", {0, 2}},
|
||||
{"SparseSoftmaxCrossEntropyGrad", {0, 1, 3}},
|
||||
{"SoftmaxCrossEntropyLoss", {0, 2}},
|
||||
{"SoftmaxCrossEntropyLossGrad", {0, 1, 3}},
|
||||
};
|
||||
|
||||
bool IsFP32(const std::unordered_map<std::string, std::vector<int>>& map, std::string opname, int argnum) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue