Update maximum batch size for UT; Include recompute modes (#5444)

* Update MaxBatchSize and include recompute mode
* Minor fix for frontend test

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Sherlock 2020-10-12 14:50:43 -07:00 committed by GitHub
parent dbc626dcbe
commit 60dbd8a1e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 12 deletions

View file

@ -25,10 +25,12 @@ bool GeluRecompute::SatisfyCondition(const Node& node) const {
Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();
for (NodeIndex i : order) {
Node& node = *graph.GetNode(i);
// Traverse backward from the bottom of the graph, so that the recompute nodes
// for lower layers are executed earlier
for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) {
Node& node = *graph.GetNode(node_ids[i]);
if (!SatisfyCondition(node)) {
continue;
@ -70,10 +72,12 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Node& node) const {
Status AttentionDropoutRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();
for (NodeIndex i : order) {
Node& node = *graph.GetNode(i);
// Traverse backward from the bottom of the graph, so that the recompute nodes
// for lower layers are executed earlier
for (int i = static_cast<int>(node_ids.size() - 1); i >= 0; --i) {
Node& node = *graph.GetNode(node_ids[i]);
if (!SatisfyCondition(node)) {
continue;

View file

@ -764,7 +764,6 @@ def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers
set_seed(seed)
# Setup ORTTrainer
loss_scaler = amp.DynamicLossScaler()
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
'graph_transformer' : {
'attn_dropout_recompute': attn_dropout,

View file

@ -19,12 +19,26 @@ def parse_args():
def main():
args = parse_args()
Config = collections.namedtuple("Config", ["enable_mixed_precision", "sequence_length", "max_batch_size"])
Config = collections.namedtuple("Config", ["enable_mixed_precision",
"sequence_length",
"max_batch_size",
"max_predictions_per_seq",
"additional_options"])
configs = [
Config(True, 128, 66),
Config(True, 512, 10),
Config(False, 128, 33),
Config(False, 512, 5),
Config(True, 128, 76, 20, ""),
Config(True, 512, 11, 80, ""),
Config(False, 128, 39, 20, ""),
Config(False, 512, 6, 80, ""),
# BertLarge Phase 1 recompute
Config(True, 128, 91, 20, "--gelu_recompute"),
Config(True, 128, 83, 20, "--attn_dropout_recompute"),
Config(True, 128, 344, 20, "--transformer_layer_recompute"),
# BertLarge Phase 2 recompute
Config(True, 512, 12, 80, "--gelu_recompute"),
Config(True, 512, 14, 80, "--attn_dropout_recompute"),
Config(True, 512, 50, 80, "--transformer_layer_recompute"),
]
# run BERT training
@ -52,6 +66,7 @@ def main():
"--use_nccl",
"--seed", "42",
"--enable_grad_norm_clip=false",
config.additional_options
]
if config.enable_mixed_precision: