mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
f4edf9bb58
commit
08001d18ac
3 changed files with 416 additions and 403 deletions
|
|
@ -67,410 +67,422 @@ using OpsetToIgnorableIndicesMap = InlinedHashMap<int, IgnorableInputIndices>;
|
|||
* or not.
|
||||
* 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not.
|
||||
*/
|
||||
InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> InitializeRecomputableOpTable() {
|
||||
InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> recomputable_op_table_map;
|
||||
|
||||
constexpr const int basic_op_level = static_cast<int>(ProbeLevel::Basic);
|
||||
recomputable_op_table_map.insert({basic_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
|
||||
auto& basic_recomputable_op_table = recomputable_op_table_map.at(basic_op_level);
|
||||
|
||||
basic_recomputable_op_table.insert({
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Add", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{9, {}},
|
||||
{14, {}},
|
||||
{15, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
|
||||
{
|
||||
{1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
|
||||
{
|
||||
{1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
|
||||
{
|
||||
{1, {1, 2}}, // ignore ratio (optional) and training mode (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
{19, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
|
||||
{
|
||||
{9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor
|
||||
{20, {0}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
|
||||
{
|
||||
{7, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
|
||||
{
|
||||
// The axis input is trivial
|
||||
{11, {1}},
|
||||
{14, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
|
||||
{
|
||||
// ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support.
|
||||
{12, {1, 2}}, // ignore ratio and training_mode
|
||||
{13, {1, 2}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Div", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
|
||||
{
|
||||
{12, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{7, {}},
|
||||
{11, {}},
|
||||
{13, {}},
|
||||
{19, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
|
||||
{
|
||||
{8, {1}}, // Ignore the shape.
|
||||
{13, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
|
||||
{
|
||||
{1, {1}}, // ignore the indices
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
|
||||
{
|
||||
{1, {1}}, // ignore the indices
|
||||
{11, {1}},
|
||||
{13, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
|
||||
{
|
||||
{20, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{9, {}},
|
||||
{11, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Less", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{7, {}},
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain),
|
||||
{
|
||||
{1, {0}}, // Ignore CPU input.
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
|
||||
{
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
|
||||
{
|
||||
{1, {1, 2}}, // ignore the indices and unflatten_dims
|
||||
},
|
||||
},
|
||||
{
|
||||
// Be noted, NOT all PythonOp will be allowed to recompute, there will be further check.
|
||||
utils::GetFullQualifiedOpName("PythonOp", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Range", kOnnxDomain),
|
||||
{
|
||||
{11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars.
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{5, {}}, // ignore the shape.
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
{19, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
|
||||
{
|
||||
{7, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional)
|
||||
{11, {1, 2, 3, 4}},
|
||||
{13, {1, 2, 3, 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Split", kOnnxDomain),
|
||||
{
|
||||
{1, {1}}, // ignore split (optional)
|
||||
{2, {}},
|
||||
{11, {}},
|
||||
{13, {1}}, // ignore the split (optional)
|
||||
{18, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{11, {}},
|
||||
{13, {1}}, // ignore the axes (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
|
||||
{
|
||||
{1, {1, 2}},
|
||||
{6, {1}},
|
||||
{13, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
|
||||
{
|
||||
{14, {1}}, // ignore k (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{11, {}},
|
||||
{13, {1}}, // ignore the axes (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Where", kOnnxDomain),
|
||||
{
|
||||
{9, {}},
|
||||
{16, {}},
|
||||
},
|
||||
},
|
||||
|
||||
});
|
||||
|
||||
constexpr const int advanced_op_level = static_cast<int>(ProbeLevel::Advanced);
|
||||
recomputable_op_table_map.insert({advanced_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
|
||||
auto& advanced_recomputable_op_table = recomputable_op_table_map.at(advanced_op_level);
|
||||
// Append basic_recomputable_op_table to advanced_recomputable_op_table.
|
||||
advanced_recomputable_op_table.insert(recomputable_op_table_map.at(basic_op_level).begin(),
|
||||
recomputable_op_table_map.at(basic_op_level).end());
|
||||
|
||||
advanced_recomputable_op_table.insert({
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
|
||||
{
|
||||
{1, {2}}, // ignore ratio (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
|
||||
{
|
||||
// Opset 1 in ONNX official does not have LayerNormalization,
|
||||
// while our contrib op defined LayerNormalization in opset 1 in ONNX domain.
|
||||
{1, {}},
|
||||
{17, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
|
||||
{
|
||||
// Opset 1 in ONNX official does not have SimplifiedLayerNormalization,
|
||||
// while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain.
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{11, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return recomputable_op_table_map;
|
||||
}
|
||||
|
||||
const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& GetAllowedRecomputeOps(int probe_op_level) {
|
||||
static InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> recomputable_op_table_map;
|
||||
if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) {
|
||||
return recomputable_op_table_map.at(probe_op_level);
|
||||
}
|
||||
static InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>>
|
||||
recomputable_op_table_map = InitializeRecomputableOpTable();
|
||||
|
||||
recomputable_op_table_map.insert({probe_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
|
||||
auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level);
|
||||
if (probe_op_level >= static_cast<int>(ProbeLevel::Basic)) {
|
||||
recomputable_op_table.insert({
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Add", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{9, {}},
|
||||
{14, {}},
|
||||
{15, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
|
||||
{
|
||||
{1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
|
||||
{
|
||||
{1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
|
||||
{
|
||||
{1, {1, 2}}, // ignore ratio (optional) and training mode (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
{19, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
ORT_ENFORCE(recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end(),
|
||||
"Cannot get recomputable op table, probe level: ", probe_op_level);
|
||||
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
|
||||
{
|
||||
{9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor
|
||||
{20, {0}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
|
||||
{
|
||||
{7, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
|
||||
{
|
||||
// The axis input is trivial
|
||||
{11, {1}},
|
||||
{14, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
|
||||
{
|
||||
// ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support.
|
||||
{12, {1, 2}}, // ignore ratio and training_mode
|
||||
{13, {1, 2}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Div", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
|
||||
{
|
||||
{12, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{7, {}},
|
||||
{11, {}},
|
||||
{13, {}},
|
||||
{19, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
|
||||
{
|
||||
{8, {1}}, // Ignore the shape.
|
||||
{13, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
|
||||
{
|
||||
{1, {1}}, // ignore the indices
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
|
||||
{
|
||||
{1, {1}}, // ignore the indices
|
||||
{11, {1}},
|
||||
{13, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
|
||||
{
|
||||
{20, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{9, {}},
|
||||
{11, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Less", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{7, {}},
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain),
|
||||
{
|
||||
{1, {0}}, // Ignore CPU input.
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
|
||||
{
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
|
||||
{
|
||||
{1, {1, 2}}, // ignore the indices and unflatten_dims
|
||||
},
|
||||
},
|
||||
{
|
||||
// Be noted, NOT all PythonOp will be allowed to recompute, there will be further check.
|
||||
utils::GetFullQualifiedOpName("PythonOp", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Range", kOnnxDomain),
|
||||
{
|
||||
{11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars.
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{5, {}}, // ignore the shape.
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
{19, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
|
||||
{
|
||||
{7, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional)
|
||||
{11, {1, 2, 3, 4}},
|
||||
{13, {1, 2, 3, 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Split", kOnnxDomain),
|
||||
{
|
||||
{1, {1}}, // ignore split (optional)
|
||||
{2, {}},
|
||||
{11, {}},
|
||||
{13, {1}}, // ignore the split (optional)
|
||||
{18, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{11, {}},
|
||||
{13, {1}}, // ignore the axes (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{6, {}},
|
||||
{7, {}},
|
||||
{13, {}},
|
||||
{14, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
|
||||
{
|
||||
{1, {1, 2}},
|
||||
{6, {1}},
|
||||
{13, {1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
|
||||
{
|
||||
{14, {1}}, // ignore k (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{11, {}},
|
||||
{13, {1}}, // ignore the axes (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Where", kOnnxDomain),
|
||||
{
|
||||
{9, {}},
|
||||
{16, {}},
|
||||
},
|
||||
},
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
if (probe_op_level >= static_cast<int>(ProbeLevel::Advanced)) {
|
||||
recomputable_op_table.insert({
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
|
||||
{
|
||||
{1, {2}}, // ignore ratio (optional)
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
|
||||
{
|
||||
// Opset 1 in ONNX official does not have LayerNormalization,
|
||||
// while our contrib op defined LayerNormalization in opset 1 in ONNX domain.
|
||||
{1, {}},
|
||||
{17, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{9, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
|
||||
{
|
||||
// Opset 1 in ONNX official does not have SimplifiedLayerNormalization,
|
||||
// while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain.
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
|
||||
{
|
||||
{1, {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
|
||||
{
|
||||
{1, {}},
|
||||
{11, {}},
|
||||
{13, {}},
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return recomputable_op_table;
|
||||
return recomputable_op_table_map.at(probe_op_level);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -542,8 +542,9 @@ TEST(TrainingApiTest, OptimStep) {
|
|||
std::string param_name = "fc2.weight";
|
||||
// before training, check if optim state is initialized to 0
|
||||
onnxruntime::training::api::OptimizerCheckpointState& optimizer_states = state.optimizer_checkpoint_state;
|
||||
std::shared_ptr<onnxruntime::training::api::GroupOptimizerState> group0_states = optimizer_states.group_named_optimizer_states["group0"];
|
||||
onnxruntime::training::api::ParameterOptimizerState& param_state =
|
||||
optimizer_states.group_named_optimizer_states["group0"]->param_named_optimizer_states.at(param_name);
|
||||
group0_states->param_named_optimizer_states.at(param_name);
|
||||
OrtValue& moment_1 = param_state.at("momentum0");
|
||||
|
||||
std::vector<float> param_vec_before_optimizer_step;
|
||||
|
|
|
|||
|
|
@ -449,7 +449,7 @@ Status FromOptimizerState(const OptimizerCheckpointState& optimizer_state,
|
|||
|
||||
fbs_optimizer_groups.reserve(optimizer_state.group_named_optimizer_states.size());
|
||||
for (const auto& group_name : SortedKeys(optimizer_state.group_named_optimizer_states)) {
|
||||
const std::shared_ptr<GroupOptimizerState>& group_optimizer_state_ptr =
|
||||
std::shared_ptr<GroupOptimizerState> group_optimizer_state_ptr =
|
||||
optimizer_state.group_named_optimizer_states.at(group_name);
|
||||
|
||||
std::vector<flatbuffers::Offset<fbs::ParameterOptimizerState>> optimizer_states;
|
||||
|
|
|
|||
Loading…
Reference in a new issue