mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Update maximum batch size for UT; Include recompute modes (#5444)
* Update MaxBatchSize and include recompute mode * Minor fix for frontend test Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
dbc626dcbe
commit
60dbd8a1e5
3 changed files with 30 additions and 12 deletions
|
|
@ -25,10 +25,12 @@ bool GeluRecompute::SatisfyCondition(const Node& node) const {
|
|||
|
||||
Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (NodeIndex i : order) {
|
||||
Node& node = *graph.GetNode(i);
|
||||
// Traverse backward from the bottom of the graph, so that the recompute nodes
|
||||
// for lower layers are executed earlier
|
||||
for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) {
|
||||
Node& node = *graph.GetNode(node_ids[i]);
|
||||
|
||||
if (!SatisfyCondition(node)) {
|
||||
continue;
|
||||
|
|
@ -70,10 +72,12 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Node& node) const {
|
|||
|
||||
Status AttentionDropoutRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (NodeIndex i : order) {
|
||||
Node& node = *graph.GetNode(i);
|
||||
// Traverse backward from the bottom of the graph, so that the recompute nodes
|
||||
// for lower layers are executed earlier
|
||||
for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) {
|
||||
Node& node = *graph.GetNode(node_ids[i]);
|
||||
|
||||
if (!SatisfyCondition(node)) {
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -764,7 +764,6 @@ def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers
|
|||
set_seed(seed)
|
||||
|
||||
# Setup ORTTrainer
|
||||
loss_scaler = amp.DynamicLossScaler()
|
||||
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
|
||||
'graph_transformer' : {
|
||||
'attn_dropout_recompute': attn_dropout,
|
||||
|
|
|
|||
|
|
@ -19,12 +19,26 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
Config = collections.namedtuple("Config", ["enable_mixed_precision", "sequence_length", "max_batch_size"])
|
||||
Config = collections.namedtuple("Config", ["enable_mixed_precision",
|
||||
"sequence_length",
|
||||
"max_batch_size",
|
||||
"max_predictions_per_seq",
|
||||
"additional_options"])
|
||||
configs = [
|
||||
Config(True, 128, 66),
|
||||
Config(True, 512, 10),
|
||||
Config(False, 128, 33),
|
||||
Config(False, 512, 5),
|
||||
Config(True, 128, 76, 20, ""),
|
||||
Config(True, 512, 11, 80, ""),
|
||||
Config(False, 128, 39, 20, ""),
|
||||
Config(False, 512, 6, 80, ""),
|
||||
|
||||
# BertLarge Phase 1 recompute
|
||||
Config(True, 128, 91, 20, "--gelu_recompute"),
|
||||
Config(True, 128, 83, 20, "--attn_dropout_recompute"),
|
||||
Config(True, 128, 344, 20, "--transformer_layer_recompute"),
|
||||
|
||||
# BertLarge Phase 2 recompute
|
||||
Config(True, 512, 12, 80, "--gelu_recompute"),
|
||||
Config(True, 512, 14, 80, "--attn_dropout_recompute"),
|
||||
Config(True, 512, 50, 80, "--transformer_layer_recompute"),
|
||||
]
|
||||
|
||||
# run BERT training
|
||||
|
|
@ -52,6 +66,7 @@ def main():
|
|||
"--use_nccl",
|
||||
"--seed", "42",
|
||||
"--enable_grad_norm_clip=false",
|
||||
config.additional_options
|
||||
]
|
||||
|
||||
if config.enable_mixed_precision:
|
||||
|
|
|
|||
Loading…
Reference in a new issue