### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
pengwa 2024-07-25 08:25:22 +08:00 committed by GitHub
parent f4edf9bb58
commit 08001d18ac
3 changed files with 416 additions and 403 deletions


@@ -67,410 +67,422 @@ using OpsetToIgnorableIndicesMap = InlinedHashMap<int, IgnorableInputIndices>;
* or not.
* 3. Some ops are not supported in older opsets, so we need to check whether recompute is applicable.
*/
InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> InitializeRecomputableOpTable() {
InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> recomputable_op_table_map;
constexpr const int basic_op_level = static_cast<int>(ProbeLevel::Basic);
recomputable_op_table_map.insert({basic_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
auto& basic_recomputable_op_table = recomputable_op_table_map.at(basic_op_level);
basic_recomputable_op_table.insert({
{
utils::GetFullQualifiedOpName("Add", kOnnxDomain),
{
{1, {}},
{6, {}},
{7, {}},
{13, {}},
{14, {}},
},
},
{
utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
{
{1, {}},
{6, {}},
{7, {}},
{9, {}},
{14, {}},
{15, {}},
},
},
{
utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
{
{1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
},
},
{
utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
{
{1, {3, 4}}, // ignore ratio (optional) and training mode (optional)
},
},
{
utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
{
{1, {1, 2}}, // ignore ratio (optional) and training mode (optional)
},
},
{
utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
{
{1, {}},
{6, {}},
{9, {}},
{13, {}},
{19, {}},
},
},
{
utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
{
{9, {0}}, // ignore the `input`, i.e. the shape of the expected output tensor
{20, {0}},
},
},
{
utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
{
{7, {}},
},
},
{
utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
{
// The axis input is trivial
{11, {1}},
{14, {1}},
},
},
{
utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
{
// ONNX Dropout opsets 1, 6, 7 and 10 do not have the seed attribute, so they are excluded from recompute support.
{12, {1, 2}}, // ignore ratio and training_mode
{13, {1, 2}},
},
},
{
utils::GetFullQualifiedOpName("Div", kOnnxDomain),
{
{1, {}},
{6, {}},
{7, {}},
{13, {}},
{14, {}},
},
},
{
utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
{
{12, {}},
},
},
{
utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
{
{1, {}},
{7, {}},
{11, {}},
{13, {}},
{19, {}},
},
},
{
utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
{
{8, {1}}, // Ignore the shape.
{13, {1}},
},
},
{
utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
{
{1, {1}}, // ignore the indices
},
},
{
utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
{
{1, {1}}, // ignore the indices
{11, {1}},
{13, {1}},
},
},
{
utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
{
{20, {}},
},
},
{
utils::GetFullQualifiedOpName("Gelu", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
{
{1, {}},
{6, {}},
{7, {}},
{9, {}},
{11, {}},
{13, {}},
},
},
{
utils::GetFullQualifiedOpName("Less", kOnnxDomain),
{
{1, {}},
{7, {}},
{9, {}},
{13, {}},
},
},
{
utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain),
{
{1, {0}}, // Ignore CPU input.
},
},
{
utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
{
{1, {}},
{6, {}},
{7, {}},
{13, {}},
{14, {}},
},
},
{
utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
{
{1, {}},
{6, {}},
{13, {}},
},
},
{
utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
{
{9, {}},
{13, {}},
},
},
{
utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
{
{1, {1, 2}}, // ignore the indices and unflatten_dims
},
},
{
// Note: NOT every PythonOp is allowed to be recomputed; additional checks are applied later.
utils::GetFullQualifiedOpName("PythonOp", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("Range", kOnnxDomain),
{
{11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars.
},
},
{
utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
{
{1, {}},
{5, {}}, // ignore the shape.
{13, {}},
{14, {}},
{19, {}},
},
},
{
utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
{
{7, {}},
},
},
{
utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
{
{1, {}},
{10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional)
{11, {1, 2, 3, 4}},
{13, {1, 2, 3, 4}},
},
},
{
utils::GetFullQualifiedOpName("Split", kOnnxDomain),
{
{1, {1}}, // ignore split (optional)
{2, {}},
{11, {}},
{13, {1}}, // ignore the split (optional)
{18, {1}},
},
},
{
utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
{
{1, {}},
{11, {}},
{13, {1}}, // ignore the axes (optional)
},
},
{
utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
{
{1, {}},
{6, {}},
{7, {}},
{13, {}},
{14, {}},
},
},
{
utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
{
{1, {1, 2}},
{6, {1}},
{13, {1}},
},
},
{
utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
{
{1, {}},
{13, {}},
},
},
{
utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
{
{14, {1}}, // ignore k (optional)
},
},
{
utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
{
{1, {}},
{11, {}},
{13, {1}}, // ignore the axes (optional)
},
},
{
utils::GetFullQualifiedOpName("Where", kOnnxDomain),
{
{9, {}},
{16, {}},
},
},
});
constexpr const int advanced_op_level = static_cast<int>(ProbeLevel::Advanced);
recomputable_op_table_map.insert({advanced_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
auto& advanced_recomputable_op_table = recomputable_op_table_map.at(advanced_op_level);
// Append basic_recomputable_op_table to advanced_recomputable_op_table.
advanced_recomputable_op_table.insert(recomputable_op_table_map.at(basic_op_level).begin(),
recomputable_op_table_map.at(basic_op_level).end());
advanced_recomputable_op_table.insert({
{
utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
{
{1, {2}}, // ignore ratio (optional)
},
},
{
utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
{
// The official ONNX opset 1 does not define LayerNormalization;
// our contrib op registers LayerNormalization as opset 1 in the ONNX domain.
{1, {}},
{17, {}},
},
},
{
utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
{
{1, {}},
{9, {}},
{13, {}},
},
},
{
utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
{
// The official ONNX opset 1 does not define SimplifiedLayerNormalization;
// our contrib op registers SimplifiedLayerNormalization as opset 1 in the ONNX domain.
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
{
{1, {}},
},
},
{
utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
{
{1, {}},
{11, {}},
{13, {}},
},
},
});
return recomputable_op_table_map;
}
const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& GetAllowedRecomputeOps(int probe_op_level) {
static InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>>
    recomputable_op_table_map = InitializeRecomputableOpTable();
ORT_ENFORCE(recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end(),
            "Cannot get recomputable op table, probe level: ", probe_op_level);
return recomputable_op_table_map.at(probe_op_level);
}
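
For readers skimming the diff: the table built above is keyed first by probe level, then by fully qualified op name, then by the opset version at which the op is eligible for recompute; the value is the set of input indices that can be ignored when counting stashed activations. The snippet below is not part of this PR, just a minimal sketch (using only names visible in the hunk) of how a caller could consult the table, e.g. for a Gather node at opset 13:

```cpp
// Illustrative lookup only -- not code from this PR.
// Decide whether a Gather(13) node is eligible for recompute at the Basic probe
// level, and which of its inputs can be skipped when collecting stashed activations.
const auto& op_table = GetAllowedRecomputeOps(static_cast<int>(ProbeLevel::Basic));
const auto op_it = op_table.find(utils::GetFullQualifiedOpName("Gather", kOnnxDomain));
if (op_it != op_table.end()) {
  const auto opset_it = op_it->second.find(13);
  if (opset_it != op_it->second.end()) {
    // Eligible. For Gather the table stores {1}, i.e. the `indices` input is ignorable.
    const auto& ignorable_input_indices = opset_it->second;
    (void)ignorable_input_indices;
  }
}
```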
/**


@@ -542,8 +542,9 @@ TEST(TrainingApiTest, OptimStep) {
std::string param_name = "fc2.weight";
// before training, check if optim state is initialized to 0
onnxruntime::training::api::OptimizerCheckpointState& optimizer_states = state.optimizer_checkpoint_state;
std::shared_ptr<onnxruntime::training::api::GroupOptimizerState> group0_states = optimizer_states.group_named_optimizer_states["group0"];
onnxruntime::training::api::ParameterOptimizerState& param_state =
    group0_states->param_named_optimizer_states.at(param_name);
OrtValue& moment_1 = param_state.at("momentum0");
std::vector<float> param_vec_before_optimizer_step;


@@ -449,7 +449,7 @@ Status FromOptimizerState(const OptimizerCheckpointState& optimizer_state,
fbs_optimizer_groups.reserve(optimizer_state.group_named_optimizer_states.size());
for (const auto& group_name : SortedKeys(optimizer_state.group_named_optimizer_states)) {
std::shared_ptr<GroupOptimizerState> group_optimizer_state_ptr =
optimizer_state.group_named_optimizer_states.at(group_name);
std::vector<flatbuffers::Offset<fbs::ParameterOptimizerState>> optimizer_states;
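
Both of the smaller hunks make the same kind of change: instead of binding a reference to the std::shared_ptr<GroupOptimizerState> stored in group_named_optimizer_states (or chaining through the map lookup inline), the code now copies the shared_ptr into a local variable first. The PR description does not state the motivation; the sketch below is a self-contained illustration of the mechanical difference, with illustrative names only:

```cpp
#include <map>
#include <memory>
#include <string>

struct GroupOptimizerState { /* payload omitted */ };

int main() {
  std::map<std::string, std::shared_ptr<GroupOptimizerState>> groups;
  groups["group0"] = std::make_shared<GroupOptimizerState>();

  // Reference form: aliases the shared_ptr object living inside the map.
  // It does not increase the reference count, so it stays valid only while
  // that map entry is neither erased nor overwritten.
  const std::shared_ptr<GroupOptimizerState>& by_ref = groups.at("group0");

  // By-value form (what the hunks switch to): copies the shared_ptr, bumping
  // the reference count, so the local handle keeps the state alive on its own.
  std::shared_ptr<GroupOptimizerState> by_value = groups.at("group0");

  (void)by_ref;
  (void)by_value;
  return 0;
}
```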